/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/recipe.h>

#include <algorithm>
#include <limits>
#include <type_traits>

#include "../common.h"
#include "../util/logging.h"
#include "../util/vectorized_pointwise.h"
#include "recipe_common.cuh"

#ifdef __HIP_PLATFORM_AMD__
#include <hipcub/hipcub.hpp>

using __nv_bfloat16 = __hip_bfloat16;

// Block size for the hipCUB-based column-wise amax kernel (v2) below.
constexpr int THREADS_PER_BLOCK = 1024;
#endif

// Tile width for the column-wise amax kernels; used on both platforms.
constexpr int kColwiseReduceTileSize = 32;

namespace transformer_engine {
namespace {

constexpr int amax_kernel_threads = 512;
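// Sanity check (added as a hedge): the block-level reduction below assumes the
// block size is a whole number of warps.
static_assert(amax_kernel_threads % THREADS_PER_WARP == 0,
              "amax_kernel_threads must be a multiple of the warp size");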

template <int nvec, bool aligned, typename InputType>
__launch_bounds__(amax_kernel_threads) __global__
    void amax_kernel(const InputType *input, float *amax, const size_t N,
                     const size_t num_aligned_elements, const float *noop_ptr) {
  if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
    return;
  }

  VectorizedLoader<InputType, nvec, aligned> loader(input, N);
  InputType max = 0.f;
  const int warp_id = threadIdx.x / THREADS_PER_WARP;
  const size_t M = num_aligned_elements;

  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M; tid += gridDim.x * blockDim.x) {
    loader.load(tid, N);
#pragma unroll
    for (int i = 0; i < nvec; ++i) {
      const InputType val = static_cast<InputType>(loader.separate()[i]);
      __builtin_assume(max >= InputType{0.f});
      if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
#if __CUDA_ARCH__ >= 800
        max = __hmax(__habs(val), max);
#else  // Turing
        max = static_cast<__nv_bfloat16>(
            fmaxf(fabsf(static_cast<float>(val)), static_cast<float>(max)));
#endif
      } else if constexpr (std::is_same_v<InputType, __half>) {
        max = __hmax(__habs(val), max);
      } else {
        max = fmaxf(fabsf(val), max);
      }
    }
  }

  // Reduce amax over block
  max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
  if (threadIdx.x == 0) {
    atomicMaxFloat(amax, max);
  }
}
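// Worked example (illustrative numbers, not from the source): for a BF16
// tensor with N = 1'000'000 elements, nvec = 32 / sizeof(__nv_bfloat16) = 16,
// so each thread loads 16 elements per vectorized access and roughly
// N / 16 = 62'500 aligned vectors are swept by the grid-stride loop before the
// warp/block reduction and the final atomicMaxFloat into *amax.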

// Warp-level max reduction via shuffle-down butterflies. Intended for float;
// `initial_offset` should be half the number of participating lanes.
template <typename T>
__inline__ __device__ T WarpReduceMax(T val, int initial_offset = 32) {
  for (int offset = initial_offset; offset > 0; offset >>= 1) {
#ifdef __HIP_PLATFORM_AMD__
    val = fmaxf(__shfl_down(val, offset), val);
#else
    val = fmaxf(__shfl_down_sync(0xffffffff, val, offset), val);
#endif
  }
  return val;
}
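// Example trace (4 lanes, initial_offset = 2) over values {3, 7, 2, 5}:
//   offset 2: lane0 = max(3, 2) = 3, lane1 = max(7, 5) = 7
//   offset 1: lane0 = max(3, 7) = 7   -> lane 0 holds the warp max.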

template <int nvec, typename InputType>
__launch_bounds__(1024) __global__
void channel_colwise_amax_kernel(float *dst, const InputType *src, const float *fp8_scale,
                                 int M, int N) {
  // Note: nvec is currently unused in this kernel.
  __shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
  const int j = blockIdx.x * blockDim.x + threadIdx.x;
  float channel_amax = 0.f;
  const float scale = fp8_scale[0];
  // Each thread accumulates the amax of (dequantized) column j over the rows
  // it strides through.
  if (j < N) {
    for (int i = threadIdx.y; i < M; i += blockDim.y) {
      channel_amax = fmaxf(fabsf(static_cast<float>(src[i * N + j]) * scale), channel_amax);
    }
  }
  g_shared[threadIdx.y][threadIdx.x] = channel_amax;
  __syncthreads();
  // Transpose in shared memory so that the 32 partial maxima of one column
  // land in a single warp, then finish the reduction with shuffles.
  float amax = g_shared[threadIdx.x][threadIdx.y];
  amax = WarpReduceMax<float>(amax, kColwiseReduceTileSize / 2);
  if (threadIdx.x == 0) {
    const int col = blockIdx.x * blockDim.x + threadIdx.y;
    if (col < N) {
      dst[col] = amax / 127.0f;  // per-channel INT8 scale
    }
  }
}
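// Illustrative shape walk-through (numbers are mine, not from the source):
// with M = 4096 rows and N = 1024 columns, the launcher below starts
// ceil(1024 / 32) = 32 blocks of 32x32 threads. Each block owns 32 columns
// outright, so the division by 127 can safely happen in-kernel -- no other
// block ever writes the same dst[col].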

#ifdef __HIP_PLATFORM_AMD__
// Alternative formulation using hipCUB block reductions: one or more blocks
// per column, combined across blocks with a float atomicMax. Currently unused
// (see the commented-out launch below).
template <typename InputType>
__launch_bounds__(THREADS_PER_BLOCK) __global__
void channel_colwise_amax_kernel_v2(const InputType *in, float *out, const float *fp8_scale,
                                    int m, int n) {
  typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
  __shared__ typename BlockReduce::TempStorage block_temp_storage;
  const float scale = fp8_scale[0];

  const int BLOCKS_PER_COL = (m + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
  const int THREADS_PER_COL = BLOCKS_PER_COL * THREADS_PER_BLOCK;
  const int idx = threadIdx.x + blockIdx.x * blockDim.x;
  const int col_idx = idx / THREADS_PER_COL;
  const int row_idx = idx % THREADS_PER_COL;
  float thread_data = 0.f;
  if (row_idx < m) {
    thread_data = fabsf(static_cast<float>(in[row_idx * n + col_idx]) * scale);
  }
  float local_amax;
  if (row_idx < (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK) {
    local_amax = BlockReduce(block_temp_storage).Reduce(thread_data, hipcub::Max());
  } else {
    // Last (partial) block of the column: only the valid rows participate.
    local_amax = BlockReduce(block_temp_storage).Reduce(
        thread_data, hipcub::Max(), m - (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK);
  }
  if (threadIdx.x == 0) {
    atomicMax(&out[col_idx], local_amax);
    // Note: the conversion to an INT8 scale (division by 127) must happen in a
    // separate pass; doing it here would race with other blocks' atomicMax.
  }
}
#endif  // __HIP_PLATFORM_AMD__
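// A minimal sketch (hypothetical helper, not part of the original file) of the
// follow-up pass that channel_colwise_amax_kernel_v2 would need: once every
// block has published its column amax via atomicMax, a second launch converts
// the amaxes into INT8 scales.
__global__ void finalize_channel_scales_kernel(float *out, int n) {
  const int j = threadIdx.x + blockIdx.x * blockDim.x;
  if (j < n) {
    out[j] = out[j] / 127.0f;  // convert per-channel amax into an INT8 scale
  }
}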

template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, const float *noop_ptr,
                        cudaStream_t stream) {
  // Zero out amax so we can update with atomic max
  NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream));

  // Return immediately if tensor is empty
  if (N == 0) {
    return;
  }

  // Figure out alignment
  auto align = CheckAlignment(N, nvec, input);
  size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType));

  // Figure out CUDA blocks
  constexpr size_t threads = amax_kernel_threads;
  size_t num_blocks = DIVUP(num_aligned_elements, threads);
  constexpr size_t max_blocks = 65535;
  num_blocks = std::min(num_blocks, max_blocks);

  // Launch kernel
  switch (align) {
    case Alignment::SAME_ALIGNED:
      amax_kernel<nvec, true, InputType>
          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements, noop_ptr);
      break;
    case Alignment::SAME_UNALIGNED:
      amax_kernel<nvec, false, InputType>
          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements, noop_ptr);
      break;
    case Alignment::DIFFERENT: {
      // This case is a logic error, since there is only one pointer (input)
      // in the alignment check. Still safe to process without vectorization.
      amax_kernel<1, true, InputType>
          <<<num_blocks, threads, 0, stream>>>(input, amax, N, N, noop_ptr);
      break;
    }
  }

  // Check results
  NVTE_CHECK_CUDA(cudaGetLastError());
}
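// Usage sketch (illustrative; `d_input` and `d_amax` stand in for device
// buffers the caller already owns):
//
//   constexpr int nvec = 32 / sizeof(__half);  // 16 FP16 values per 32 B
//   launch_amax_kernel<nvec>(d_input, d_amax, num_elements,
//                            /*noop_ptr=*/nullptr, stream);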

template <int nvec, typename InputType>
void launch_channel_colwise_amax_kernel(const InputType *input, float *amax,
                                        const float *fp8_scale, const size_t M, const size_t N,
                                        cudaStream_t stream) {
  // Zero out the per-channel amax buffer (needed by the atomicMax path in v2;
  // harmless for v1)
  NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, N * sizeof(float), stream));

  // Launch kernel: one 32x32 block per group of kColwiseReduceTileSize columns
  const int num_blocks = (N - 1) / kColwiseReduceTileSize + 1;
  channel_colwise_amax_kernel<nvec, InputType>
      <<<num_blocks, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(
          amax, input, fp8_scale, M, N);

  // Alternative launch using the hipCUB-based kernel (v2), kept for reference:
  // dim3 block, grid;
  // int BLOCKS_PER_COL = (M + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK;
  // block.x = THREADS_PER_BLOCK;
  // grid.x = BLOCKS_PER_COL * N;
  // hipLaunchKernelGGL((channel_colwise_amax_kernel_v2<InputType>), grid, block, 0, stream,
  //                    input, amax, fp8_scale, M, N);

  // Check results
  NVTE_CHECK_CUDA(cudaGetLastError());
}
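// Design note (my reading of the two variants): v1 assigns each column to
// exactly one block, so it needs no atomics and folds the 1/127 scaling into
// the kernel; v2 trades that for hipCUB block reductions plus a device-wide
// atomicMax, which presumably pays off only when M is large enough that one
// block per 32 columns under-occupies the GPU.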

}  // namespace
}  // namespace transformer_engine

namespace {

void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream,
                       const NVTEQuantizationConfig config_) {
  using namespace transformer_engine;

  // Check input tensor
  NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)");
  const auto &input = *convertNVTETensorCheck(input_);
  NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Input tensor for amax computation must be unquantized, "
             "but got scaling_mode=",
             to_string(input.scaling_mode));
  NVTE_CHECK(!is_fp8_dtype(input.data.dtype),
             "Input tensor for amax computation must be unquantized, but got dtype=",
             to_string(input.data.dtype));
  NVTE_CHECK(input.data.dptr != nullptr, "Input tensor for amax computation has no data");
  CheckInputTensor(input, "input_compute_amax");

  // Check output tensor
  NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
  auto &output = *convertNVTETensorCheck(output_);
  NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
             "but got scaling_mode=",
             to_string(output.scaling_mode));
  NVTE_CHECK(output.amax.numel() == 1,
             "Output tensor for amax computation has invalid amax tensor "
             "(expected 1 entry, got shape=",
             output.amax.shape, ")");
  NVTE_CHECK(output.amax.dptr != nullptr,
             "Output tensor for amax computation has amax tensor without data");
  NVTE_CHECK(output.amax.dtype == DType::kFloat32,
             "Output tensor for amax computation has invalid amax tensor  "
             "(expected FP32, got dtype=",
             to_string(output.amax.dtype), ")");
  CheckOutputTensor(output, "output_compute_amax", true);

  // noop tensor for cuda graph
  float *noop_ptr = nullptr;
  if (config_ != nullptr) {
    const auto &config = *reinterpret_cast<const QuantizationConfig *>(config_);

    // Extract the noop tensor from the config if one was provided
    const NVTETensor noop = config.noop_tensor;
    noop_ptr = reinterpret_cast<float *>(
        noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr);
  }

  // Compute amax
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
      launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
                               reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
                               noop_ptr, stream););  // NOLINT(*)
}

}  // anonymous namespace

void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
  NVTE_API_CALL(nvte_compute_amax);
  compute_amax_impl(input_, output_, stream, nullptr);
}

void nvte_compute_amax_with_config(const NVTETensor input_, const NVTETensor output_,
                                   const NVTEQuantizationConfig config_, cudaStream_t stream) {
  NVTE_API_CALL(nvte_compute_amax_with_config);
  compute_amax_impl(input_, output_, stream, config_);
}
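// Usage sketch for the two public entry points above (illustrative; the
// tensors are assumed to have been created through the usual NVTETensor API,
// and `cfg` is a hypothetical NVTEQuantizationConfig):
//
//   nvte_compute_amax(input, output, stream);                   // amax only
//   nvte_compute_amax_with_config(input, output, cfg, stream);  // honors noop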

void nvte_compute_channel_colwise_amax(const NVTETensor input_, const NVTETensor output_,
                                       const NVTETensor fp8_scale_, cudaStream_t stream) {
  NVTE_API_CALL(nvte_compute_channel_colwise_amax);
  using namespace transformer_engine;

  // Check input tensor
  NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)");
  NVTE_CHECK(fp8_scale_ != nullptr, "Invalid fp8 scale tensor (got NULL)");
  const auto &input = *convertNVTETensorCheck(input_);
  const auto &fp8_scale = *convertNVTETensorCheck(fp8_scale_);
  NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Input tensor for channel-wise amax computation must use per-tensor scaling, "
             "but got scaling_mode=",
             to_string(input.scaling_mode));
  NVTE_CHECK(input.data.dptr != nullptr, "Input tensor for amax computation has no data");

  // Check output tensor
  NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
  auto &output = *convertNVTETensorCheck(output_);
  NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
             "but got scaling_mode=",
             to_string(output.scaling_mode));
  CheckOutputTensor(output, "output_compute_amax", true);

  // Compute per-channel scales from column-wise amaxes
  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
      input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
      launch_channel_colwise_amax_kernel<nvec>(
          reinterpret_cast<const IType *>(input.data.dptr),
          reinterpret_cast<float *>(output.data.dptr),
          reinterpret_cast<const float *>(fp8_scale.data.dptr), input.data.shape[0],
          input.data.shape[1], stream););  // NOLINT(*)
}

namespace transformer_engine {
namespace {

__global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
                                               const float max_fp8, const bool force_pow_2_scales,
                                               const float epsilon, const float *noop_ptr) {
  if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
    return;
  }

  *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon,
                                       std::numeric_limits<float>::max());
}
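// Worked example (my numbers; assumes the usual current-scaling formula
// scale ~= max_fp8 / amax): with amax = 3.0 and FP8 E4M3 (max_fp8 = 448),
// scale ~= 149.33; with force_pow_2_scales set, the result is rounded down to
// a power of two, i.e. 128.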

}  // namespace
}  // namespace transformer_engine

void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConfig config_,
                                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_compute_scale_from_amax);
  using namespace transformer_engine;

  // Check output tensor
  NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
  auto &output = *convertNVTETensorCheck(output_);
  NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Tensor must be FP8 tensor with per-tensor scaling, "
             "but got scaling_mode=",
             to_string(output.scaling_mode));
  NVTE_CHECK(is_fp8_dtype(output.data.dtype) || is_int8_dtype(output.data.dtype),
             "Tensor must be FP8 or INT8, but got dtype=", to_string(output.data.dtype));
  NVTE_CHECK(output.amax.numel() == 1,
             "Tensor has invalid amax tensor (expected 1 entry, got shape=", output.amax.shape,
             ")");
  NVTE_CHECK(output.amax.dptr != nullptr, "Tensor has amax tensor without data");
  NVTE_CHECK(output.amax.dtype == DType::kFloat32,
             "Tensor has invalid amax tensor (expected FP32, got dtype=",
             to_string(output.amax.dtype), ")");
  NVTE_CHECK(output.scale.numel() == 1,
             "Tensor has invalid scale tensor (expected 1 entry, got shape=", output.scale.shape,
             ")");
  NVTE_CHECK(output.scale.dptr != nullptr, "Tensor has scale tensor without data");
  NVTE_CHECK(output.scale.dtype == DType::kFloat32,
             "Tensor has invalid scale tensor (expected FP32, got dtype=",
             to_string(output.scale.dtype), ")");

  // Check config
  NVTE_CHECK(config_ != nullptr, "Invalid config (got NULL)");
  const auto &config = *reinterpret_cast<const QuantizationConfig *>(config_);

  // Maximum FP8 value
  float max_fp8 = 0.f;
  TRANSFORMER_ENGINE_TYPE_SWITCH_8BIT(output.data.dtype, DType,
                                      max_fp8 = Quantized_Limits<DType>::max_norm;);

  // noop tensor for cuda graph (config_ was already validated above)
  const NVTETensor noop = config.noop_tensor;
  float *noop_ptr = reinterpret_cast<float *>(
      noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr);

  // Update scale
  compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(
      reinterpret_cast<const float *>(output.amax.dptr),
      reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales,
      config.amax_epsilon, noop_ptr);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
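// End-to-end sketch of the current-scaling flow (illustrative ordering only;
// `hp_input`, `fp8_out`, and `cfg` are hypothetical caller-owned handles):
//
//   nvte_compute_amax_with_config(hp_input, fp8_out, cfg, stream);
//   nvte_compute_scale_from_amax(fp8_out, cfg, stream);
//   // ... then quantize hp_input into fp8_out using fp8_out's scale ...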