/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/recipe.h>

#include <algorithm>
#include <cmath>
#include <limits>
#include <type_traits>

#include "../common.h"
#include "../util/logging.h"
#include "../util/vectorized_pointwise.h"
#include "recipe_common.cuh"

#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_bf16.h>
#include <hipcub/hipcub.hpp>  // hipcub::BlockReduce used below
using __nv_bfloat16 = __hip_bfloat16;
constexpr int kColwiseReduceTileSize = 32;
constexpr int THREADS_PER_BLOCK = 1024;
#endif

namespace transformer_engine {
namespace {

constexpr int amax_kernel_threads = 512;

template <int nvec, bool aligned, typename InputType>
__launch_bounds__(amax_kernel_threads) __global__
    void amax_kernel(const InputType *input, float *amax, const size_t N,
                     const size_t num_aligned_elements) {
  VectorizedLoader<InputType, nvec, aligned> loader(input, N);
  InputType max = 0.f;
  const int warp_id = threadIdx.x / THREADS_PER_WARP;
  const size_t M = num_aligned_elements;
  for (size_t tid = blockIdx.x * blockDim.x + threadIdx.x; tid < M;
       tid += gridDim.x * blockDim.x) {
    loader.load(tid, N);
#pragma unroll
    for (int i = 0; i < nvec; ++i) {
      const InputType val = static_cast<InputType>(loader.separate()[i]);
      __builtin_assume(max >= InputType{0.f});
      if constexpr (std::is_same_v<InputType, __nv_bfloat16>) {
#if __CUDA_ARCH__ >= 800
        max = __hmax(__habs(val), max);
#else  // Turing
        max = static_cast<__nv_bfloat16>(
            fmaxf(fabsf(static_cast<float>(val)), static_cast<float>(max)));
#endif
      } else if constexpr (std::is_same_v<InputType, __half>) {
        max = __hmax(__habs(val), max);
      } else {
        max = fmaxf(fabsf(val), max);
      }
    }
  }

  // Reduce amax over block
  max = reduce_max<amax_kernel_threads / THREADS_PER_WARP>(max, warp_id);
  if (threadIdx.x == 0) {
    atomicMaxFloat(amax, max);
  }
}

template <typename T>
__inline__ __device__ T WarpReduceMax(T val, int max = 32) {
  for (int offset = max; offset > 0; offset >>= 1) {
    val = fmaxf(__shfl_down(val, offset), val);
  }
  return val;
}

template <typename InputType>
__launch_bounds__(1024) __global__
    void channel_colwise_amax_kernel(float *dst, const InputType *src, const float *fp8_scale,
                                     int M, int N) {
  __shared__ float g_shared[kColwiseReduceTileSize][kColwiseReduceTileSize];
  const int j = blockIdx.x * blockDim.x + threadIdx.x;
  float channel_amax = 0.f;
  float scale = fp8_scale[0];
  if (j < N) {
    for (int i = threadIdx.y; i < M; i += blockDim.y) {
      channel_amax = fmaxf(fabsf(static_cast<float>(src[i * N + j]) * scale), channel_amax);
    }
  }
  g_shared[threadIdx.y][threadIdx.x] = channel_amax;
  __syncthreads();
  // Transpose in shared memory so each warp reduces one column's partial maxima.
  float amax = g_shared[threadIdx.x][threadIdx.y];
  amax = WarpReduceMax(amax, kColwiseReduceTileSize / 2);
  if (threadIdx.x == 0) {
    const int j = blockIdx.x * blockDim.x + threadIdx.y;
    if (j < N) {
      dst[j] = static_cast<float>(amax) / 127.0;  // scales
    }
  }
}

template <typename InputType>
__launch_bounds__(THREADS_PER_BLOCK) __global__
    void channel_colwise_amax_kernel_v2(const InputType *in, float *out, const float *fp8_scale,
                                        int m, int n) {
  typedef hipcub::BlockReduce<float, THREADS_PER_BLOCK> BlockReduce;
  __shared__ typename BlockReduce::TempStorage block_temp_storage;

  float scale = fp8_scale[0];
  int BLOCKS_PER_COL = ceil(float(m) / THREADS_PER_BLOCK);
  int THREADS_PER_COL = BLOCKS_PER_COL * THREADS_PER_BLOCK;
  int idx = threadIdx.x + blockIdx.x * blockDim.x;
  int col_idx = idx / THREADS_PER_COL;
  int row_idx = idx % THREADS_PER_COL;

  float thread_data = 0.f;
  if (row_idx < m) thread_data = fabsf((float)in[row_idx * n + col_idx] * scale);

  float local_amax;
  if (row_idx < (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK) {
    // Full blocks of the column: reduce over all THREADS_PER_BLOCK values.
    local_amax = BlockReduce(block_temp_storage).Reduce(thread_data, hipcub::Max());
  } else {
    // Last, possibly partial block of the column: reduce only over the valid tail.
    local_amax = BlockReduce(block_temp_storage)
                     .Reduce(thread_data, hipcub::Max(),
                             m - (BLOCKS_PER_COL - 1) * THREADS_PER_BLOCK);
  }
  if (threadIdx.x == 0) {
    atomicMax(&out[col_idx], local_amax);
    out[col_idx] = out[col_idx] / 127.0;
  }
}
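// Illustrative thread mapping for channel_colwise_amax_kernel_v2 (this v2 launch
// path is currently commented out in launch_channel_colwise_amax_kernel below).
// Assuming the commented-out configuration of BLOCKS_PER_COL * n blocks of
// THREADS_PER_BLOCK threads, e.g. m = 3000 rows with THREADS_PER_BLOCK = 1024:
//   BLOCKS_PER_COL  = ceil(3000 / 1024) = 3
//   THREADS_PER_COL = 3 * 1024          = 3072
// A global thread index idx then maps to column idx / 3072 and row idx % 3072;
// the first two blocks of each column reduce a full 1024 values, while the third
// reduces only the 3000 - 2048 = 952 valid rows via the num_valid overload of
// hipcub::BlockReduce::Reduce.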
template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N,
                        cudaStream_t stream) {
  // Zero out amax so we can update with atomic max
  cudaMemsetAsync(amax, 0, sizeof(float), stream);

  // Return immediately if tensor is empty
  if (N == 0) {
    return;
  }

  // Figure out alignment
  auto align = CheckAlignment(N, nvec, input);
  size_t num_aligned_elements = get_num_aligned_elements(input, N, nvec, sizeof(InputType));

  // Figure out CUDA blocks
  constexpr size_t threads = amax_kernel_threads;
  size_t num_blocks = DIVUP(num_aligned_elements, threads);
  constexpr size_t max_blocks = 65535;
  num_blocks = std::min(num_blocks, max_blocks);

  // Launch kernel
  switch (align) {
    case Alignment::SAME_ALIGNED:
      amax_kernel<nvec, true, InputType>
          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
      break;
    case Alignment::SAME_UNALIGNED:
      amax_kernel<nvec, false, InputType>
          <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
      break;
    case Alignment::DIFFERENT: {
      // This case is a logic error, since there is only one pointer (input)
      // in the alignment check. Still safe to process without vectorization.
      amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
      break;
    }
  }

  // Check results
  NVTE_CHECK_CUDA(cudaGetLastError());
}

template <int nvec, typename InputType>
void launch_channel_colwise_amax_kernel(const InputType *input, float *amax,
                                        const float *fp8_scale, const size_t M, const size_t N,
                                        cudaStream_t stream) {
  // Zero out amax so we can update with atomic max
  cudaMemsetAsync(amax, 0, N * sizeof(float), stream);

  // Launch kernel: one block of kColwiseReduceTileSize x kColwiseReduceTileSize threads
  // per tile of columns.
  int B = (N - 1) / kColwiseReduceTileSize + 1;
  channel_colwise_amax_kernel<InputType>
      <<<B, dim3(kColwiseReduceTileSize, kColwiseReduceTileSize), 0, stream>>>(amax, input,
                                                                               fp8_scale, M, N);

  // Launch kernel v2
  // dim3 block, grid;
  // int BLOCKS_PER_COL = ceil(float(M) / THREADS_PER_BLOCK);
  // block.x = THREADS_PER_BLOCK;
  // grid.x = BLOCKS_PER_COL * N;
  // hipLaunchKernelGGL((channel_colwise_amax_kernel_v2), dim3(grid), dim3(block), 0, stream,
  //                    input, amax, fp8_scale, M, N);

  // Check results
  NVTE_CHECK_CUDA(cudaGetLastError());
}

}  // namespace
}  // namespace transformer_engine

void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
  NVTE_API_CALL(nvte_compute_amax);
  using namespace transformer_engine;

  // Check input tensor
  NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)");
  const auto &input = *reinterpret_cast<const Tensor *>(input_);
  NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Input tensor for amax computation must be unquantized, "
             "but got scaling_mode=",
             to_string(input.scaling_mode));
  NVTE_CHECK(!is_fp8_dtype(input.data.dtype),
             "Input tensor for amax computation must be unquantized, but got dtype=",
             to_string(input.data.dtype));
  NVTE_CHECK(input.data.dptr != nullptr, "Input tensor for amax computation has no data");
  CheckInputTensor(input, "input_compute_amax");

  // Check output tensor
  NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
  auto &output = *reinterpret_cast<Tensor *>(output_);
  NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
             "but got scaling_mode=",
             to_string(output.scaling_mode));
  NVTE_CHECK(output.amax.numel() == 1,
             "Output tensor for amax computation has invalid amax tensor "
             "(expected 1 entry, got shape=",
             output.amax.shape, ")");
  NVTE_CHECK(output.amax.dptr != nullptr,
             "Output tensor for amax computation has amax tensor without data");
  NVTE_CHECK(output.amax.dtype == DType::kFloat32,
             "Output tensor for amax computation has invalid amax tensor "
             "(expected FP32, got dtype=",
             to_string(output.amax.dtype), ")");
  CheckOutputTensor(output, "output_compute_amax", true);

  // Compute amax
  TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
      input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
      launch_amax_kernel<nvec, IType>(reinterpret_cast<const IType *>(input.data.dptr),
                                      reinterpret_cast<float *>(output.amax.dptr),
                                      input.data.numel(), stream););  // NOLINT(*)
}
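// Note on nvte_compute_channel_colwise_amax below, inferred from the kernels above
// rather than from separate documentation: the input is treated as a row-major
// [M, N] matrix of FP8 values, each element is multiplied by the single factor read
// from fp8_scale, and output.data receives N float values, one per column, equal to
// the column-wise amax divided by 127 (an int8-style per-channel scale).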
void nvte_compute_channel_colwise_amax(const NVTETensor input_, const NVTETensor output_,
                                       const NVTETensor fp8_scale_, cudaStream_t stream) {
  NVTE_API_CALL(nvte_compute_channel_colwise_amax);
  using namespace transformer_engine;

  // Check input tensors
  NVTE_CHECK(input_ != nullptr, "Invalid input tensor (got NULL)");
  NVTE_CHECK(fp8_scale_ != nullptr, "Invalid fp8 scale tensor (got NULL)");
  const auto &input = *convertNVTETensorCheck(input_);
  const auto &fp8_scale = *convertNVTETensorCheck(fp8_scale_);
  NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Input tensor for amax computation must be unquantized, "
             "but got scaling_mode=",
             to_string(input.scaling_mode));
  NVTE_CHECK(input.data.dptr != nullptr, "Input tensor for amax computation has no data");

  // Check output tensor
  NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
  auto &output = *convertNVTETensorCheck(output_);
  NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Output tensor for amax computation must be FP8 tensor with per-tensor scaling, "
             "but got scaling_mode=",
             to_string(output.scaling_mode));
  CheckOutputTensor(output, "output_compute_amax", true);

  // Compute per-column amax
  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
      input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
      launch_channel_colwise_amax_kernel<nvec, IType>(
          reinterpret_cast<const IType *>(input.data.dptr),
          reinterpret_cast<float *>(output.data.dptr),
          reinterpret_cast<const float *>(fp8_scale.data.dptr), input.data.shape[0],
          input.data.shape[1], stream););  // NOLINT(*)
}

namespace transformer_engine {
namespace {

__global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
                                               const float max_fp8, const bool force_pow_2_scales,
                                               const float epsilon) {
  *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon,
                                       std::numeric_limits<float>::max());
}

}  // namespace
}  // namespace transformer_engine
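// Rough sketch of what compute_scale_from_amax (recipe_common.cuh) produces, for
// orientation only; the authoritative definition lives in that header. The scale is
// approximately max_fp8 / max(amax, epsilon), clamped to a finite value. For example,
// with an E4M3 max_norm of 448 and amax = 10.0 the scale is 44.8; when
// force_pow_2_scales is set the result is additionally snapped to a power of two.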
void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConfig config_,
                                  cudaStream_t stream) {
  NVTE_API_CALL(nvte_compute_scale_from_amax);
  using namespace transformer_engine;

  // Check output tensor
  NVTE_CHECK(output_ != nullptr, "Invalid output tensor (got NULL)");
  auto &output = *reinterpret_cast<Tensor *>(output_);
  NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Tensor must be FP8 tensor with per-tensor scaling, "
             "but got scaling_mode=",
             to_string(output.scaling_mode));
  NVTE_CHECK(is_fp8_dtype(output.data.dtype),
             "Tensor must be FP8, but got dtype=", to_string(output.data.dtype));
  NVTE_CHECK(output.amax.numel() == 1,
             "Tensor has invalid amax tensor (expected 1 entry, got shape=", output.amax.shape,
             ")");
  NVTE_CHECK(output.amax.dptr != nullptr, "Tensor has amax tensor without data");
  NVTE_CHECK(output.amax.dtype == DType::kFloat32,
             "Tensor has invalid amax tensor (expected FP32, got dtype=",
             to_string(output.amax.dtype), ")");
  NVTE_CHECK(output.scale.numel() == 1,
             "Tensor has invalid scale tensor (expected 1 entry, got shape=", output.scale.shape,
             ")");
  NVTE_CHECK(output.scale.dptr != nullptr, "Tensor has scale tensor without data");
  NVTE_CHECK(output.scale.dtype == DType::kFloat32,
             "Tensor has invalid scale tensor (expected FP32, got dtype=",
             to_string(output.scale.dtype), ")");

  // Check config
  NVTE_CHECK(config_ != nullptr, "Invalid config (got NULL)");
  const auto &config = *reinterpret_cast<const QuantizationConfig *>(config_);

  // Maximum FP8 value
  float max_fp8 = 0.f;
  TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType,
                                         max_fp8 = Quantized_Limits<DType>::max_norm;);

  // Update scale
  compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(
      reinterpret_cast<const float *>(output.amax.dptr),
      reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales,
      config.amax_epsilon);
  NVTE_CHECK_CUDA(cudaGetLastError());
}
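// Usage-order note (illustrative): callers typically run nvte_compute_amax to populate
// output.amax and then nvte_compute_scale_from_amax to derive output.scale from it.
// Both enqueue work on the provided stream, so no extra synchronization is needed
// between the two calls when the same stream is used.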