/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/recipe.h>

#include <cmath>
#include <limits>
#include <string>

#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/logging.h"

namespace transformer_engine {
namespace delayed_scaling_recipe {

namespace {

// amax value to use for updating scaling factor
enum class AmaxComputeAlgo { INVALID, MOST_RECENT, MAX };

const char* dtype_name(DType dtype) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, Type,
                                     return TypeInfo<Type>::name;);  // NOLINT(*)
  return "";
}

// Maximum representable value of an FP8 dtype
inline float fp8_dtype_max(DType dtype) {
  switch (dtype) {
    case DType::kFloat8E4M3:
#ifndef __HIP_PLATFORM_AMD__
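      // OCP FP8 E4M3 has a maximum finite value of 448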
      return 448;
#else
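      // HIP/ROCm builds use the FP8 E4M3 FNUZ variant, whose maximum finite value is 240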
      return 240;
#endif
    case DType::kFloat8E5M2:
      return 57344;
    default:
      NVTE_ERROR("Expected FP8 dtype, but got ", dtype_name(dtype));
  }
  return 0;
}

// struct for amax parameters
struct AmaxParam {
  int num_scale = 0;
  float* amax_history = nullptr;
  float* scale = nullptr;
};

// dummy struct for kernel_bulk's other params
struct OtherParams {
  float* a;
  size_t b;
  AmaxComputeAlgo c;
  float d;
};

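// Kernel arguments are passed by value through constant memory, so the size of
// AmaxParams is bounded by the per-kernel parameter limit (raised from 4 KiB to
// 32 KiB in CUDA 12.1). This caps how many tensors a single kernel_bulk launch
// can process.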
#if CUDART_VERSION >= 12010
constexpr size_t max_constant_memory_per_kernel = 32768;
constexpr size_t AMAX_PARAMS_LIMIT =
    (max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);
#else
constexpr size_t max_constant_memory_per_kernel = 4096;
constexpr size_t AMAX_PARAMS_LIMIT =
    (max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);
#endif

struct AmaxParams {
  AmaxParam param[AMAX_PARAMS_LIMIT];
};

namespace amax_and_scale_update_impl {

// CUDA block size
constexpr size_t bsize = 256;

/* CUDA kernel to update amax history and FP8 scaling factors
 *
 * Block dims: bsize x 1 x 1
 *
 * Grid dims: num_scales x 1 x 1
 */
__global__ void __launch_bounds__(bsize)
    kernel(const float* amax_history_ptr, const float* scale_ptr, float* updated_amax_history_ptr,
           float* updated_scale_ptr, size_t amax_history_length, size_t amax_history_stride,
           AmaxComputeAlgo amax_compute_algo, float scaled_max) {
  const size_t tid = threadIdx.x;
  const size_t bid = blockIdx.x;

  // Update amax
  float amax = 0;
  {
    // Roll amax history
    const auto* amax_history = amax_history_ptr + bid;
    auto* updated_amax_history = updated_amax_history_ptr + bid;
    const auto last_amax = amax_history[0];
    const auto& length = amax_history_length;
    const auto& stride = amax_history_stride;
    for (size_t off = 0; off < length; off += bsize) {
      const size_t i = off + tid;
      float a = 0;
      if (i < length) {
        a = (i < length - 1) ? amax_history[(i + 1) * stride] : last_amax;
        amax = fmaxf(amax, a);
      }
      __syncthreads();  // In case roll is in-place
      if (i < length) {
        updated_amax_history[i * stride] = (i > 0) ? a : 0;
      }
    }

    // Compute amax to use for scaling factor
    switch (amax_compute_algo) {
      case AmaxComputeAlgo::MOST_RECENT:
        amax = last_amax;
        break;
      case AmaxComputeAlgo::MAX: {
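        // Max-reduce the per-thread partial amaxes across the block in shared memory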
        __shared__ float shared_amax[bsize];
        shared_amax[tid] = amax;
        __syncthreads();
#pragma unroll
        for (size_t off = bsize / 2; off > 0; off /= 2) {
          if (tid < off) {
            shared_amax[tid] = fmaxf(shared_amax[tid], shared_amax[tid + off]);
          }
          __syncthreads();
        }
        amax = shared_amax[tid];
      } break;
      default:
        amax = 0;
    }
  }

  // Update scale
  if (tid == 0) {
    // Update scale
    float scale;
    if (isfinite(amax) && amax > 0) {
      scale = scaled_max / amax;
    } else {
      scale = scale_ptr[bid];
    }
    // If the amax is so tiny that the scale becomes infinite in FP32, clamp
    // the scale to the FP32 max value. The tensor's amax then maps to
    // something below the FP8 max representable value, but this is the best
    // we can do.
    if (isinf(scale)) {
      scale = std::numeric_limits<float>::max();
    }
    updated_scale_ptr[bid] = scale;
  }
}

/* CUDA kernel to bulk-update amax history and FP8 scaling factors
 *
 * Block dims: bsize x 1 x 1
 *
 * Grid dims: num_tensors x 1 x 1
 */
__global__ void __launch_bounds__(bsize)
    kernel_bulk(float* amax_reduction_buffer, AmaxParams p, size_t amax_history_length,
                AmaxComputeAlgo amax_compute_algo, float scaled_max) {
  const size_t bid = blockIdx.x;
  const size_t tid = threadIdx.x;
  const int num_scale = p.param[bid].num_scale;

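  // Each block handles one tensor; its amaxes start in the flat reduction buffer
  // after the scales of all preceding tensors.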
  int offset_in_buffer = 0;
  for (int j = 0; j < bid; j++) {
    offset_in_buffer += p.param[j].num_scale;
  }

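  // Each iteration rolls one column of this tensor's amax history and refreshes
  // the corresponding scaling factor.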
  for (int count = 0; count < num_scale; count++) {
    // Update amax
    float amax = 0;
    {
      // Roll amax history
      const auto& length = amax_history_length;
      const auto& stride = p.param[bid].num_scale;
      auto* amax_history = p.param[bid].amax_history + count;
      const auto last_amax = ((amax_reduction_buffer != nullptr) &&
                              (amax_reduction_buffer[offset_in_buffer + count] != 0.0f))
                                 ? amax_reduction_buffer[offset_in_buffer + count]
                                 : amax_history[0];
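      // If no amax was recorded for this scale, skip the roll so the history stays unchanged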
      if (last_amax != 0.0f) {
        for (size_t off = 0; off < length; off += bsize) {
          const size_t i = off + tid;
          float a = 0;
          if (i < length) {
            a = (i < length - 1) ? amax_history[(i + 1) * stride] : last_amax;
            amax = fmaxf(amax, a);
          }
          __syncthreads();  // Inplace roll
          if (i < length) {
            amax_history[i * stride] = (i > 0) ? a : 0;
          }
        }
      }

      // Compute amax to use for scaling factor
      switch (amax_compute_algo) {
        case AmaxComputeAlgo::MOST_RECENT:
          amax = last_amax;
          break;
        case AmaxComputeAlgo::MAX: {
          __shared__ float shared_amax[bsize];
          shared_amax[tid] = amax;
          __syncthreads();
#pragma unroll
          for (size_t off = bsize / 2; off > 0; off /= 2) {
            if (tid < off) {
              shared_amax[tid] = fmaxf(shared_amax[tid], shared_amax[tid + off]);
            }
            __syncthreads();
          }
          amax = shared_amax[tid];
        } break;
        default:
          amax = 0;
      }
    }

    // Update scale
    if (tid == 0) {
      // Computing the scaling factor requires consideration of the following scenarios:
      // 1. amax == 0:
      //    No action is possible, set scale to the previous scale (or 1).
      // 2. 0 < amax < tiny_amax:
      //    The amax is so tiny that the scale becomes infinite in FP32.
      //    Set scale = FP32_max.
      // 3. tiny_amax <= amax < FP32_max:
      //    Set scale = FP8_max (or scaled_max) / amax
      // 4. When amax == inf or amax == nan:
      //    No action is possible, set scale to the previous scale (or 1).

      float scale;
      if (isfinite(amax) && amax > 0) {
        scale = scaled_max / amax;
      } else {
        scale = p.param[bid].scale[count];
      }
      // If the amax is so tiny that the scale becomes infinite in FP32, clamp
      // the scale to the FP32 max value. The tensor's amax then maps to
      // something below the FP8 max representable value, but this is the best
      // we can do.
      if (isinf(scale)) {
        scale = std::numeric_limits<float>::max();
      }
      p.param[bid].scale[count] = scale;
    }
  }
}

}  // namespace amax_and_scale_update_impl

}  // namespace

void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale,
                           Tensor* updated_amax_history_, Tensor* updated_scale_,
                           const std::string& amax_compute_algo, DType fp8_dtype, float margin,
                           cudaStream_t stream) {
  auto& updated_amax_history = *updated_amax_history_;
  auto& updated_scale = *updated_scale_;

  // Check tensors
  NVTE_CHECK(amax_history.data.shape.size() == 2, "Found ", amax_history.data.shape.size(),
             " dims");
  const size_t amax_history_length = amax_history.data.shape[0];
  const size_t num_scales = amax_history.data.shape[1];
  NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, "Found ",
             dtype_name(amax_history.data.dtype), ".");
  NVTE_CHECK(scale.numel() == num_scales, "Expected ", num_scales, " elements, ", "but found ",
             scale.numel(), ".");
  NVTE_CHECK(scale.data.dtype == DType::kFloat32, "Found ", dtype_name(scale.data.dtype), ".");
  NVTE_CHECK(updated_amax_history.data.shape.size() == 2, "Found ",
             updated_amax_history.data.shape.size(), " dims.");
  NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, "Expected ",
             amax_history_length, ", ", "but found ", updated_amax_history.data.shape[0]);
  NVTE_CHECK(updated_amax_history.data.shape[1] == num_scales, "Expected ", num_scales, ", ",
             "but found ", updated_amax_history.data.shape[1]);
  NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, "Got ",
             dtype_name(updated_amax_history.data.dtype), ".");
  NVTE_CHECK(updated_scale.numel() == num_scales, "Expected ", num_scales, " elements, ",
             "but found ", updated_scale.numel(), ".");
  NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, "Got ",
             dtype_name(updated_scale.data.dtype), ".");

  // amax value to use for updating scaling factor
  AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID;
  if (amax_compute_algo == "max") {
    amax_compute_algo_ = AmaxComputeAlgo::MAX;
  } else if (amax_compute_algo == "most_recent") {
    amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT;
  } else {
    NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")");
  }

  // Expected maximum value after scale is applied
  const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin);

  // Launch CUDA kernel
  constexpr size_t block_size = amax_and_scale_update_impl::bsize;
  const size_t grid_size = num_scales;
  amax_and_scale_update_impl::kernel<<<grid_size, block_size, 0, stream>>>(
      static_cast<const float*>(amax_history.data.dptr), static_cast<const float*>(scale.data.dptr),
      static_cast<float*>(updated_amax_history.data.dptr),
      static_cast<float*>(updated_scale.data.dptr), amax_history_length, num_scales,
      amax_compute_algo_, scaled_max);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer,
                                           std::vector<Tensor*> amax_histories,
                                           std::vector<Tensor*> scales,
                                           const std::string& amax_compute_algo, DType fp8_dtype,
                                           float margin, cudaStream_t stream) {
  using namespace transformer_engine;

  // amax value to use for updating scaling factor
  AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID;
  if (amax_compute_algo == "max") {
    amax_compute_algo_ = AmaxComputeAlgo::MAX;
  } else if (amax_compute_algo == "most_recent") {
    amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT;
  } else {
    NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")");
  }

  // Expected maximum value after scale is applied
  const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin);

  // Number of tensors in the bulk
  const size_t num_tensors = amax_histories.size();
  size_t num_remaining_tensors = num_tensors;
  const int num_kernels = (num_tensors + AMAX_PARAMS_LIMIT - 1) / AMAX_PARAMS_LIMIT;
  size_t amax_history_length = 0;
  if (num_tensors > 0) {
    amax_history_length = amax_histories[0]->data.shape[0];
  }

  // amax parameters
  float* amax_buffer = static_cast<float*>(amax_reduction_buffer.data.dptr);
  AmaxParams p;
  for (int iter = 0; iter < num_kernels; iter++) {
    size_t kernel_num_scales = 0;
    size_t kernel_num_tensors =
        (iter == (num_kernels - 1)) ? num_remaining_tensors : AMAX_PARAMS_LIMIT;
    for (size_t pi = 0; pi < kernel_num_tensors; pi++) {
      size_t i = iter * AMAX_PARAMS_LIMIT + pi;

      // Check tensors
      int num_scale = amax_histories[i]->data.shape[1];
      NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32, "Found ",
                 dtype_name(amax_histories[i]->data.dtype), ".");
      NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, "Found ",
                 amax_histories[i]->data.shape.size(), " dims");
      NVTE_CHECK(amax_histories[i]->numel() == amax_history_length * num_scale, "Expected ",
                 amax_history_length * num_scale, " elements, ", "but found ",
                 amax_histories[i]->numel(), ".");
      NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, "Found ",
                 dtype_name(scales[i]->data.dtype), ".");
      NVTE_CHECK(scales[i]->data.shape.size() == 1, "Found ", scales[i]->data.shape.size(),
                 " dims");
      NVTE_CHECK(scales[i]->numel() == num_scale, "Expected ", num_scale, " elements, ",
                 "but found ", scales[i]->numel(), ".");

      // amax parameters
      kernel_num_scales += num_scale;
      p.param[pi].num_scale = num_scale;
      p.param[pi].amax_history = static_cast<float*>(amax_histories[i]->data.dptr);
      p.param[pi].scale = static_cast<float*>(scales[i]->data.dptr);
    }

    // Launch CUDA kernel
    size_t grid_size = kernel_num_tensors;
    const size_t block_size = amax_and_scale_update_impl::bsize;
    amax_and_scale_update_impl::kernel_bulk<<<grid_size, block_size, 0, stream>>>(
        amax_buffer, p, amax_history_length, amax_compute_algo_, scaled_max);
    NVTE_CHECK_CUDA(cudaGetLastError());

    // shift amax buffer pointer
    if (amax_buffer != nullptr) {
      amax_buffer += kernel_num_scales;
    }
    num_remaining_tensors -= AMAX_PARAMS_LIMIT;
  }
}

}  // namespace delayed_scaling_recipe
}  // namespace transformer_engine

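// C API: roll the amax history and update the FP8 scale for a single
// amax_history/scale tensor pair.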
void nvte_delayed_scaling_recipe_amax_and_scale_update(
    const NVTETensor amax_history, const NVTETensor scale, NVTETensor updated_amax_history,
    NVTETensor updated_scale, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin,
    cudaStream_t stream) {
  NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update);
  using namespace transformer_engine;
  delayed_scaling_recipe::amax_and_scale_update(
      *reinterpret_cast<const Tensor*>(amax_history), *reinterpret_cast<const Tensor*>(scale),
      reinterpret_cast<Tensor*>(updated_amax_history), reinterpret_cast<Tensor*>(updated_scale),
      amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream);
}

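// C API: bulk variant that updates many tensors at once, optionally using amaxes
// gathered in a reduction buffer (e.g. after a cross-device amax reduction).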
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
    const NVTETensor amax_reduction_buffer, std::vector<NVTETensor> amax_histories,
    std::vector<NVTETensor> scales, const char* amax_compute_algo, NVTEDType fp8_dtype,
    float margin, cudaStream_t stream) {
  NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction);
  using namespace transformer_engine;
  size_t num_tensors = amax_histories.size();
  std::vector<Tensor*> t_amax_histories, t_scales;
  for (size_t i = 0; i < num_tensors; i++) {
    t_amax_histories.push_back(reinterpret_cast<Tensor*>(amax_histories[i]));
    t_scales.push_back(reinterpret_cast<Tensor*>(scales[i]));
  }
  delayed_scaling_recipe::amax_and_scale_update_after_reduction(
      *reinterpret_cast<const Tensor*>(amax_reduction_buffer), t_amax_histories, t_scales,
      amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream);
}