/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/recipe.h>

#include <cmath>
#include <limits>
#include <string>

#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/logging.h"

namespace transformer_engine {
namespace delayed_scaling_recipe {

namespace {

// amax value to use for updating scaling factor
enum class AmaxComputeAlgo { INVALID, MOST_RECENT, MAX };

// Human-readable name of a DType, for use in error messages.
// The type switch covers all supported dtypes and returns from inside the
// macro expansion; the trailing return "" only silences compiler warnings.
const char* dtype_name(DType dtype) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, Type,
                                     return TypeInfo<Type>::name;);  // NOLINT(*)
  return "";
}

// Maximum representable value of an FP8 dtype.
// Errors out (via NVTE_ERROR) for any non-FP8 dtype.
inline float fp8_dtype_max(DType dtype) {
  if (dtype == DType::kFloat8E4M3) {
    return 448;
  }
  if (dtype == DType::kFloat8E5M2) {
    return 57344;
  }
  NVTE_ERROR("Expected FP8 dtype, but got ", dtype_name(dtype));
  return 0;  // unreachable: NVTE_ERROR does not return normally
}

// struct for amax parameters
// Per-tensor pointers and scale count passed to kernel_bulk by value
// (i.e. through kernel parameter space, hence the size budgeting below).
struct AmaxParam {
  int num_scale = 0;              // number of scales for this tensor
  float* amax_history = nullptr;  // amax history buffer (length x num_scale)
  float* scale = nullptr;         // scale buffer (num_scale entries)
  float* scale_inv = nullptr;     // scale-inverse buffer (num_scale entries)
};

// dummy struct for kernel_bulk's other params
// Mirrors the size of kernel_bulk's non-AmaxParams arguments so the
// parameter-space budget below can subtract their footprint.
struct OtherParams {
  float* a;
  size_t b;
  AmaxComputeAlgo c;
  float d;
};

#if CUDART_VERSION >= 12010
// CUDA 12.1+ raises the kernel parameter limit to 32 KB.
constexpr size_t max_constant_memory_per_kernel = 32768;
constexpr size_t AMAX_PARAMS_LIMIT =
    (max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);
#else
// Older CUDA toolkits limit kernel parameters to 4 KB.
constexpr size_t max_constant_memory_per_kernel = 4096;
constexpr size_t AMAX_PARAMS_LIMIT =
    (max_constant_memory_per_kernel - sizeof(OtherParams)) / sizeof(AmaxParam);
#endif

// Fixed-capacity array of per-tensor params; at most AMAX_PARAMS_LIMIT
// tensors can be handled per kernel_bulk launch.
struct AmaxParams {
  AmaxParam param[AMAX_PARAMS_LIMIT];
};

namespace amax_and_scale_update_impl {

// CUDA block size
constexpr size_t bsize = 256;

/* CUDA kernel to update amax history and FP8 scaling factors
 *
 * Rolls one column of the [length, num_scales] amax history forward by one
 * step, computes the amax to use for scaling (most recent entry or max over
 * the history), then derives scale = scaled_max / amax and
 * scale_inv = 1 / scale. scale_inv_mask_ptr may be null; when non-null,
 * entries whose mask byte is zero keep their previous scale_inv.
 *
 * Block dims: bsize x 1 x 1
 *
 * Grid dims: num_scales x 1 x 1
 */
__global__ void __launch_bounds__(bsize)
    kernel(const float* amax_history_ptr, const float* scale_ptr, const float* scale_inv_ptr,
           const unsigned char* scale_inv_mask_ptr, float* updated_amax_history_ptr,
           float* updated_scale_ptr, float* updated_scale_inv_ptr, size_t amax_history_length,
           size_t amax_history_stride, AmaxComputeAlgo amax_compute_algo, float scaled_max) {
  const size_t tid = threadIdx.x;
  const size_t bid = blockIdx.x;  // one block per scale/history column

  // Update amax
  float amax = 0;
  {
    // Roll amax history: updated entry i takes the value of old entry i+1,
    // the most recent amax (old entry 0) moves to the end of the history,
    // and the new entry 0 is zeroed for the next iteration's amax capture.
    const auto* amax_history = amax_history_ptr + bid;
    auto* updated_amax_history = updated_amax_history_ptr + bid;
    const auto last_amax = amax_history[0];
    const auto& length = amax_history_length;
    const auto& stride = amax_history_stride;
    for (size_t off = 0; off < length; off += bsize) {
      const size_t i = off + tid;
      float a = 0;
      if (i < length) {
        a = (i < length - 1) ? amax_history[(i + 1) * stride] : last_amax;
        amax = fmaxf(amax, a);  // running per-thread max over the history
      }
      __syncthreads();  // In case roll is in-place
      if (i < length) {
        updated_amax_history[i * stride] = (i > 0) ? a : 0;
      }
    }

    // Compute amax to use for scaling factor
    switch (amax_compute_algo) {
      case AmaxComputeAlgo::MOST_RECENT:
        amax = last_amax;
        break;
      case AmaxComputeAlgo::MAX: {
        // Block-wide tree reduction over the per-thread partial maxima.
        // After the loop, shared_amax[0] holds the full maximum; slots for
        // tid > 0 hold partial results. That is safe because only thread 0
        // consumes `amax` in the scale update below.
        __shared__ float shared_amax[bsize];
        shared_amax[tid] = amax;
        __syncthreads();
#pragma unroll
        for (size_t off = bsize / 2; off > 0; off /= 2) {
          if (tid < off) {
            shared_amax[tid] = fmaxf(shared_amax[tid], shared_amax[tid + off]);
          }
          __syncthreads();
        }
        amax = shared_amax[tid];
      } break;
      default:
        // INVALID (or unrecognized) algo: fall through to "keep old scale"
        // below, since amax == 0 fails the isfinite-and-positive test.
        amax = 0;
    }
  }

  // Update scale and scale inverse
  if (tid == 0) {
    // Update scale
    float scale;
    if (isfinite(amax) && amax > 0) {
      scale = scaled_max / amax;
    } else {
      // amax is zero, inf, or nan: no valid update, keep the previous scale
      scale = scale_ptr[bid];
    }
    // When the amax is too tiny that the scale becoming infinite in FP32,
    // we set the scale to the max value of FP32. In this case, the tensor's
    // amax won't get mapped to the FP8 max representable, but rather
    // something below that, but this is the best thing we can do.
    if (isinf(scale)) {
      scale = std::numeric_limits<float>::max();
    }
    updated_scale_ptr[bid] = scale;

    // Update scale inverse
    float scale_inv;
    if (scale_inv_mask_ptr == nullptr || scale_inv_mask_ptr[bid]) {
      scale_inv = 1 / scale;
    } else {
      // Masked-out entry: carry the previous scale_inv through unchanged
      scale_inv = scale_inv_ptr[bid];
    }
    updated_scale_inv_ptr[bid] = scale_inv;
  }
}

/* CUDA kernel to bulk-update amax history and FP8 scaling factors
 *
 * One block per tensor; each block loops over its tensor's scales. The roll
 * is performed in place on the tensor's own amax history. When
 * amax_reduction_buffer is non-null and holds a non-zero entry for a scale,
 * that (globally reduced) value overrides the local most-recent amax.
 * A last_amax of exactly 0 skips the roll entirely (nothing was recorded).
 *
 * Block dims: bsize x 1 x 1
 *
 * Grid dims: num_tensors x 1 x 1
 */
__global__ void __launch_bounds__(bsize)
    kernel_bulk(float* amax_reduction_buffer, AmaxParams p, size_t amax_history_length,
                AmaxComputeAlgo amax_compute_algo, float scaled_max) {
  const size_t bid = blockIdx.x;
  const size_t tid = threadIdx.x;
  const int num_scale = p.param[bid].num_scale;

  // Offset of this tensor's first scale within the flat reduction buffer.
  // NOTE(review): `int j` compared against size_t `bid` mixes signedness —
  // harmless for realistic grid sizes, but worth normalizing upstream.
  int offset_in_buffer = 0;
  for (int j = 0; j < bid; j++) {
    offset_in_buffer += p.param[j].num_scale;
  }

  for (int count = 0; count < num_scale; count++) {
    // Update amax
    float amax = 0;
    {
      // Roll amax history (column `count` of this tensor's history)
      const auto& length = amax_history_length;
      const auto& stride = p.param[bid].num_scale;
      auto* amax_history = p.param[bid].amax_history + count;
      // Prefer the cross-device reduced amax when available and non-zero;
      // otherwise fall back to the locally recorded most-recent amax.
      const auto last_amax = ((amax_reduction_buffer != nullptr) &&
                              (amax_reduction_buffer[offset_in_buffer + count] != 0.0f))
                                 ? amax_reduction_buffer[offset_in_buffer + count]
                                 : amax_history[0];
      // Skip the roll when no amax was recorded; condition is uniform
      // across the block, so the __syncthreads below stays non-divergent.
      if (last_amax != 0.0f) {
        for (size_t off = 0; off < length; off += bsize) {
          const size_t i = off + tid;
          float a = 0;
          if (i < length) {
            a = (i < length - 1) ? amax_history[(i + 1) * stride] : last_amax;
            amax = fmaxf(amax, a);  // running per-thread max
          }
          __syncthreads();  // Inplace roll
          if (i < length) {
            amax_history[i * stride] = (i > 0) ? a : 0;
          }
        }
      }

      // Compute amax to use for scaling factor
      switch (amax_compute_algo) {
        case AmaxComputeAlgo::MOST_RECENT:
          amax = last_amax;
          break;
        case AmaxComputeAlgo::MAX: {
          // Block-wide tree reduction; shared_amax[0] ends up with the full
          // maximum. Only thread 0 consumes `amax` below, so the partial
          // values read by tid > 0 are never used.
          __shared__ float shared_amax[bsize];
          shared_amax[tid] = amax;
          __syncthreads();
#pragma unroll
          for (size_t off = bsize / 2; off > 0; off /= 2) {
            if (tid < off) {
              shared_amax[tid] = fmaxf(shared_amax[tid], shared_amax[tid + off]);
            }
            __syncthreads();
          }
          amax = shared_amax[tid];
        } break;
        default:
          // INVALID algo: amax == 0 forces the "keep old scale" path below
          amax = 0;
      }
    }

    // Update scale and scale inverse
    if (tid == 0) {
      // Computing the scaling factor requires consideration of the following scenarios:
      // 1. amax == 0:
      //    No action is possible, set scale to the previous scale (or 1).
      // 2. 0 < amax < tiny_amax
      //    The amax is too tiny that the scale becomes infinite in FP32.
      //    Set scale = FP32_max
      // 3. tiny_amax <= amax < FP32_max:
      //    Set scale = FP8_max (or scaled_max) / amax
      // 4. When amax == inf or amax == nan:
      //    No action is possible, set scale to the previous scale (or 1).

      float scale;
      if (isfinite(amax) && amax > 0) {
        scale = scaled_max / amax;
      } else {
        scale = p.param[bid].scale[count];
      }
      // When the amax is too tiny that the scale becoming infinite in FP32,
      // we set the scale to the max value of FP32. In this case, the tensor's
      // amax won't get mapped to the FP8 max representable, but rather
      // something below that, but this is the best thing we can do.
      if (isinf(scale)) {
        scale = std::numeric_limits<float>::max();
      }
      p.param[bid].scale[count] = scale;
      p.param[bid].scale_inv[count] = 1 / scale;
    }
  }
}

}  // namespace amax_and_scale_update_impl

}  // namespace

/* Update FP8 amax history and scaling factors for one set of tensors.
 *
 * amax_history must be a 2D FP32 tensor of shape [history_length,
 * num_scales]; scale holds num_scales FP32 elements. scale_inv_mask may be
 * empty (null dptr), in which case every scale_inv entry is recomputed;
 * when present, it must hold num_scales bytes and scale_inv must hold
 * num_scales FP32 elements. Results are written to updated_amax_history,
 * updated_scale, and updated_scale_inv, which must match the input
 * shapes/dtypes. Raises (via NVTE_CHECK/NVTE_ERROR) on shape/dtype
 * mismatches or an unrecognized amax_compute_algo ("max"/"most_recent").
 */
void amax_and_scale_update(const Tensor& amax_history, const Tensor& scale, const Tensor& scale_inv,
                           const Tensor& scale_inv_mask, Tensor* updated_amax_history_,
                           Tensor* updated_scale_, Tensor* updated_scale_inv_,
                           const std::string& amax_compute_algo, DType fp8_dtype, float margin,
                           cudaStream_t stream) {
  auto& updated_amax_history = *updated_amax_history_;
  auto& updated_scale = *updated_scale_;
  auto& updated_scale_inv = *updated_scale_inv_;

  // Number of elements in tensor
  auto numel = [](const Tensor& tensor) -> size_t {
    size_t acc = 1;
    for (const auto& dim : tensor.data.shape) {
      acc *= dim;
    }
    return acc;
  };

  // Check tensors
  NVTE_CHECK(amax_history.data.shape.size() == 2, "Found ", amax_history.data.shape.size(),
             " dims");
  const size_t amax_history_length = amax_history.data.shape[0];
  const size_t num_scales = amax_history.data.shape[1];
  NVTE_CHECK(amax_history.data.dtype == DType::kFloat32, "Found ",
             dtype_name(amax_history.data.dtype), ".");
  NVTE_CHECK(numel(scale) == num_scales, "Expected ", num_scales, " elements, ", "but found ",
             numel(scale), ".");
  NVTE_CHECK(scale.data.dtype == DType::kFloat32, "Found ", dtype_name(scale.data.dtype), ".");
  // scale_inv is only read when a mask is provided, so its checks live here.
  if (scale_inv_mask.data.dptr != nullptr) {
    NVTE_CHECK(numel(scale_inv) == num_scales, "Expected ", num_scales, " elements, ", "but found ",
               numel(scale_inv), ".");
    NVTE_CHECK(scale_inv.data.dtype == DType::kFloat32);
    NVTE_CHECK(numel(scale_inv_mask) == num_scales, "Expected ", num_scales, " elements, ",
               "but found ", numel(scale_inv_mask), ".");
    NVTE_CHECK(scale_inv_mask.data.dtype == DType::kByte, "Found ",
               dtype_name(scale_inv_mask.data.dtype), ".");
  }
  NVTE_CHECK(updated_amax_history.data.shape.size() == 2, "Found ",
             updated_amax_history.data.shape.size(), " dims.");
  NVTE_CHECK(updated_amax_history.data.shape[0] == amax_history_length, "Expected ",
             amax_history_length, ", ", "but found ", updated_amax_history.data.shape[0]);
  NVTE_CHECK(updated_amax_history.data.shape[1] == num_scales, "Expected ", num_scales, ", ",
             "but found ", updated_amax_history.data.shape[1]);
  NVTE_CHECK(updated_amax_history.data.dtype == DType::kFloat32, "Got ",
             dtype_name(updated_amax_history.data.dtype), ".");
  NVTE_CHECK(numel(updated_scale) == num_scales, "Expected ", num_scales, " elements, ",
             "but found ", numel(updated_scale), ".");
  NVTE_CHECK(updated_scale.data.dtype == DType::kFloat32, "Got ",
             dtype_name(updated_scale.data.dtype), ".");
  NVTE_CHECK(numel(updated_scale_inv) == num_scales, "Expected ", num_scales, " elements, ",
             "but found ", numel(updated_scale_inv), ".");
  NVTE_CHECK(updated_scale_inv.data.dtype == DType::kFloat32, "Got ",
             dtype_name(updated_scale_inv.data.dtype), ".");

  // amax value to use for updating scaling factor
  AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID;
  if (amax_compute_algo == "max") {
    amax_compute_algo_ = AmaxComputeAlgo::MAX;
  } else if (amax_compute_algo == "most_recent") {
    amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT;
  } else {
    NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")");
  }

  // Expected maximum value after scale is applied
  const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin);

  // Nothing to update for empty tensors. This also avoids launching with a
  // zero-sized grid, which is an invalid CUDA launch configuration.
  if (num_scales == 0) {
    return;
  }

  // Launch CUDA kernel
  constexpr size_t block_size = amax_and_scale_update_impl::bsize;
  const size_t grid_size = num_scales;
  amax_and_scale_update_impl::kernel<<<grid_size, block_size, 0, stream>>>(
      static_cast<const float*>(amax_history.data.dptr), static_cast<const float*>(scale.data.dptr),
      static_cast<const float*>(scale_inv.data.dptr),
      static_cast<const unsigned char*>(scale_inv_mask.data.dptr),
      static_cast<float*>(updated_amax_history.data.dptr),
      static_cast<float*>(updated_scale.data.dptr),
      static_cast<float*>(updated_scale_inv.data.dptr), amax_history_length, num_scales,
      amax_compute_algo_, scaled_max);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

/* Bulk-update FP8 amax histories and scaling factors after an amax
 * reduction.
 *
 * amax_histories[i] must be a 2D FP32 tensor [history_length, num_scale_i];
 * scales[i] a 1D FP32 tensor of num_scale_i elements. All histories are
 * assumed to share the history length of amax_histories[0].
 * amax_reduction_buffer, when non-empty, holds the reduced amaxes for all
 * tensors' scales concatenated in order. Because the per-tensor parameters
 * are passed through kernel parameter space, the work is split into
 * ceil(num_tensors / AMAX_PARAMS_LIMIT) kernel launches.
 */
void amax_and_scale_update_after_reduction(const Tensor& amax_reduction_buffer,
                                           std::vector<Tensor*> amax_histories,
                                           std::vector<Tensor*> scales,
                                           std::vector<Tensor*> scale_invs,
                                           const std::string& amax_compute_algo, DType fp8_dtype,
                                           float margin, cudaStream_t stream) {
  using namespace transformer_engine;

  // amax value to use for updating scaling factor
  AmaxComputeAlgo amax_compute_algo_ = AmaxComputeAlgo::INVALID;
  if (amax_compute_algo == "max") {
    amax_compute_algo_ = AmaxComputeAlgo::MAX;
  } else if (amax_compute_algo == "most_recent") {
    amax_compute_algo_ = AmaxComputeAlgo::MOST_RECENT;
  } else {
    NVTE_ERROR("Unsupported amax compute algorithm (", amax_compute_algo, ")");
  }

  // Expected maximum value after scale is applied
  const float scaled_max = fp8_dtype_max(fp8_dtype) * std::pow(2.f, -margin);

  // Number of elements in tensor
  auto numel = [](const Tensor* tensor) -> size_t {
    size_t acc = 1;
    for (const auto& dim : tensor->data.shape) {
      acc *= dim;
    }
    return acc;
  };

  // Number of tensors in the bulk
  const size_t num_tensors = amax_histories.size();
  size_t num_remaining_tensors = num_tensors;
  // Number of launches needed to stay within the kernel-parameter limit
  const int num_kernels = (num_tensors + AMAX_PARAMS_LIMIT - 1) / AMAX_PARAMS_LIMIT;
  size_t amax_history_length = 0;
  if (num_tensors > 0) {
    amax_history_length = amax_histories[0]->data.shape[0];
  }

  // amax parameters
  float* amax_buffer = static_cast<float*>(amax_reduction_buffer.data.dptr);
  AmaxParams p;
  for (int iter = 0; iter < num_kernels; iter++) {
    size_t kernel_num_scales = 0;
    size_t kernel_num_tensors =
        (iter == (num_kernels - 1)) ? num_remaining_tensors : AMAX_PARAMS_LIMIT;
    for (size_t pi = 0; pi < kernel_num_tensors; pi++) {
      size_t i = iter * AMAX_PARAMS_LIMIT + pi;

      // Check tensors
      int num_scale = amax_histories[i]->data.shape[1];
      NVTE_CHECK(amax_histories[i]->data.dtype == DType::kFloat32, "Found ",
                 dtype_name(amax_histories[i]->data.dtype), ".");
      NVTE_CHECK(amax_histories[i]->data.shape.size() == 2, "Found ",
                 amax_histories[i]->data.shape.size(), " dims");
      NVTE_CHECK(numel(amax_histories[i]) == amax_history_length * num_scale, "Expected ",
                 amax_history_length * num_scale, " elements, ", "but found ",
                 numel(amax_histories[i]), ".");
      NVTE_CHECK(scales[i]->data.dtype == DType::kFloat32, "Found ",
                 dtype_name(scales[i]->data.dtype), ".");
      NVTE_CHECK(scales[i]->data.shape.size() == 1, "Found ", scales[i]->data.shape.size(),
                 " dims");
      NVTE_CHECK(numel(scales[i]) == num_scale, "Expected ", num_scale, " elements, ", "Found ",
                 numel(scales[i]), ".");

      // amax parameters
      kernel_num_scales += num_scale;
      p.param[pi].num_scale = num_scale;
      p.param[pi].amax_history = static_cast<float*>(amax_histories[i]->data.dptr);
      p.param[pi].scale = static_cast<float*>(scales[i]->data.dptr);
      p.param[pi].scale_inv = static_cast<float*>(scale_invs[i]->data.dptr);
    }

    // Launch CUDA kernel
    size_t grid_size = kernel_num_tensors;
    const size_t block_size = amax_and_scale_update_impl::bsize;
    amax_and_scale_update_impl::kernel_bulk<<<grid_size, block_size, 0, stream>>>(
        amax_buffer, p, amax_history_length, amax_compute_algo_, scaled_max);
    NVTE_CHECK_CUDA(cudaGetLastError());

    // shift amax buffer pointer
    if (amax_buffer != nullptr) {
      amax_buffer += kernel_num_scales;
    }
    // Subtract the tensors actually processed this launch. Subtracting
    // AMAX_PARAMS_LIMIT unconditionally (the previous code) underflowed
    // this size_t on the final iteration when fewer tensors remained.
    num_remaining_tensors -= kernel_num_tensors;
  }
}

}  // namespace delayed_scaling_recipe
}  // namespace transformer_engine

/* C API entry point: update amax history and FP8 scales for one tensor set.
 *
 * Thin wrapper that casts the opaque NVTETensor handles to the internal
 * Tensor type and forwards all arguments unchanged to
 * delayed_scaling_recipe::amax_and_scale_update.
 */
void nvte_delayed_scaling_recipe_amax_and_scale_update(
    const NVTETensor amax_history, const NVTETensor scale, const NVTETensor scale_inv,
    const NVTETensor scale_inv_mask, NVTETensor updated_amax_history, NVTETensor updated_scale,
    NVTETensor updated_scale_inv, const char* amax_compute_algo, NVTEDType fp8_dtype, float margin,
    cudaStream_t stream) {
  NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update);
  using namespace transformer_engine;
  delayed_scaling_recipe::amax_and_scale_update(
      *reinterpret_cast<const Tensor*>(amax_history), *reinterpret_cast<const Tensor*>(scale),
      *reinterpret_cast<const Tensor*>(scale_inv), *reinterpret_cast<const Tensor*>(scale_inv_mask),
      reinterpret_cast<Tensor*>(updated_amax_history), reinterpret_cast<Tensor*>(updated_scale),
      reinterpret_cast<Tensor*>(updated_scale_inv), amax_compute_algo,
      static_cast<DType>(fp8_dtype), margin, stream);
}

/* C API entry point: bulk-update amax histories and FP8 scales after an
 * amax reduction.
 *
 * Converts the vectors of opaque NVTETensor handles into vectors of
 * internal Tensor pointers and forwards to
 * delayed_scaling_recipe::amax_and_scale_update_after_reduction.
 */
void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
    const NVTETensor amax_reduction_buffer, std::vector<NVTETensor> amax_histories,
    std::vector<NVTETensor> scales, std::vector<NVTETensor> scale_invs,
    const char* amax_compute_algo, NVTEDType fp8_dtype, float margin, cudaStream_t stream) {
  NVTE_API_CALL(nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction);
  using namespace transformer_engine;
  const size_t count = amax_histories.size();
  std::vector<Tensor*> hist_ptrs;
  std::vector<Tensor*> scale_ptrs;
  std::vector<Tensor*> scale_inv_ptrs;
  hist_ptrs.reserve(count);
  scale_ptrs.reserve(count);
  scale_inv_ptrs.reserve(count);
  for (size_t idx = 0; idx < count; ++idx) {
    hist_ptrs.push_back(reinterpret_cast<Tensor*>(amax_histories[idx]));
    scale_ptrs.push_back(reinterpret_cast<Tensor*>(scales[idx]));
    scale_inv_ptrs.push_back(reinterpret_cast<Tensor*>(scale_invs[idx]));
  }
  delayed_scaling_recipe::amax_and_scale_update_after_reduction(
      *reinterpret_cast<const Tensor*>(amax_reduction_buffer), hist_ptrs, scale_ptrs,
      scale_inv_ptrs, amax_compute_algo, static_cast<DType>(fp8_dtype), margin, stream);
}