Commit 87e3e56e authored by yuguo

Merge commit '734bcedd' of https://github.com/NVIDIA/TransformerEngine

parents 2f11bd2e 734bcedd
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include "utils.cuh"
using namespace transformer_engine;
namespace {
// Parameters
using VectorType = BytesToType<__VECTOR_SIZE__>::Type;
constexpr size_t block_size = __BLOCK_SIZE__;
} // namespace
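
// Swap the first two tensor dimensions: output[i0][i1][i2] = input[i1][i0][i2],
// where each element is a vector of __VECTOR_SIZE__ bytes. When
// __SINGLE_LOAD_STORE__ is set, every thread handles exactly one element;
// otherwise a grid-stride loop covers the full tensor.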
__global__ void __launch_bounds__(block_size)
    swap_first_dims_kernel(const VectorType* __restrict__ const input,
                           VectorType* __restrict__ const output, const size_t dim0,
                           const size_t dim1, const size_t dim2) {
  const size_t gid = threadIdx.x + blockIdx.x * block_size;
#if __SINGLE_LOAD_STORE__
  const auto idx = gid;
#else
  const size_t nthreads = gridDim.x * block_size;
  for (size_t idx = gid; idx < dim0 * dim1 * dim2; idx += nthreads)
#endif  // __SINGLE_LOAD_STORE__
  {
    const auto idx2 = idx % dim2;
    const auto idx1 = (idx / dim2) % dim1;
    const auto idx0 = (idx / dim2) / dim1;
    const auto in_offset = idx1 * dim0 * dim2 + idx0 * dim2 + idx2;
    output[idx] = input[in_offset];
  }
}
/*************************************************************************
* Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda_runtime.h>
#include <transformer_engine/transpose.h>
#include <algorithm>
#include <cstdint>
#include <string>
#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/logging.h"
#include "../util/rtc.h"
#include "../util/string.h"
namespace transformer_engine {
namespace {
// String with RTC kernel implementation
#include "string_code_transpose_rtc_swap_first_dims_cu.h"
// Hard-coded kernel parameters
constexpr size_t block_size = 128;
/* Performance heuristics for optimized kernel parameters */
struct KernelConfig {
  /* Vector load/store size */
  size_t vector_size;
  /* Whether config is valid */
  bool valid = false;
  /* Number of CUDA blocks */
  size_t num_blocks = 0;
  /* Whether each thread needs to make exactly one load/store */
  bool single_load_store = true;
  /* Number of active SMs */
  size_t active_sm_count = 0;
  /* Used bytes per L1 cache load */
  size_t bytes_per_load = 0;
  /* Used bytes per L1 cache store */
  size_t bytes_per_store = 0;

  KernelConfig(size_t dim0, size_t dim1, size_t dim2, size_t sm_count, size_t vector_size_)
      : vector_size{vector_size_} {
    // Check that tiles are correctly aligned
    if (dim2 % vector_size_ != 0) {
      return;
    }
    valid = true;

    // Number of CUDA blocks
    num_blocks = DIVUP(dim0 * dim1 * dim2 / vector_size, block_size);
    if (num_blocks > 2147483647ull) {
      // Maximum number of CUDA blocks
      single_load_store = false;
      num_blocks = 2147483647ull;
    } else if (num_blocks * block_size != dim0 * dim1 * dim2 / vector_size) {
      single_load_store = false;
    }

    // SM occupancy
    constexpr size_t warp_size = 32;
    constexpr size_t warps_per_sm = 16;  // Rough estimate for saturated SMs
    active_sm_count = std::min(DIVUP(num_blocks * block_size / warp_size, warps_per_sm), sm_count);

    // L1 cache efficiency
    constexpr size_t cache_line_size = 128;
    bytes_per_store = std::min(cache_line_size, warp_size * vector_size);  // Contiguous writes
    bytes_per_load = bytes_per_store;
    if (dim2 % (vector_size * warp_size) != 0) {
      // Some warps are reading from two non-contiguous regions
      bytes_per_load /= 2;
    }
  }

  /* Compare by estimated cost */
  bool operator<(const KernelConfig &other) const {
    if (this->valid && other.valid) {
      // cost ~ (1/bytes_per_load + 1/bytes_per_store) / active_sms
      // Note: Integer arithmetic ensures stable ordering
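      // Multiplying both costs by scale = l1*s1*p1*l2*s2*p2 > 0 clears all
      // denominators, so the divisions below are exact and the integer
      // comparison agrees with the real-valued costs.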
      const auto &l1 = this->bytes_per_load;
      const auto &s1 = this->bytes_per_store;
      const auto &p1 = this->active_sm_count;
      const auto &l2 = other.bytes_per_load;
      const auto &s2 = other.bytes_per_store;
      const auto &p2 = other.active_sm_count;
      const auto scale = l1 * s1 * p1 * l2 * s2 * p2;
      const auto cost1 = (scale / l1 + scale / s1) / p1;
      const auto cost2 = (scale / l2 + scale / s2) / p2;
      return cost1 < cost2;
    } else {
      return this->valid && !other.valid;
    }
  }
};
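
// Fallback kernel: byte-wise grid-stride loop, used when vectorized
// loads/stores do not apply (e.g. unaligned dim2) or NVRTC is disabled.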
template <typename Type>
__global__ void __launch_bounds__(block_size)
    swap_first_dims_untuned_kernel(const Type *__restrict__ input, Type *__restrict__ output,
                                   const size_t dim0, const size_t dim1, const size_t dim2) {
  const size_t gid = threadIdx.x + blockIdx.x * block_size;
  const size_t nthreads = gridDim.x * block_size;
  for (size_t idx = gid; idx < dim0 * dim1 * dim2; idx += nthreads) {
    const auto idx2 = idx % dim2;
    const auto idx1 = (idx / dim2) % dim1;
    const auto idx0 = (idx / dim2) / dim1;
    const auto in_offset = idx1 * dim0 * dim2 + idx0 * dim2 + idx2;
    output[idx] = input[in_offset];
  }
}
} // namespace
void swap_first_dims(const Tensor &input, Tensor &output, cudaStream_t stream) {
  // Check tensors
  NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Input tensor must be simple tensor, but scaling mode is ",
             to_string(input.scaling_mode), ".");
  NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING,
             "Output tensor must be simple tensor, but scaling mode is ",
             to_string(output.scaling_mode), ".");
  NVTE_CHECK(input.dtype() == output.dtype(), "Input tensor (dtype=", to_string(input.dtype()),
             ") and output tensor (dtype=", to_string(output.dtype()), ") do not match.");
  NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated.");
  NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated.");

  // Check tensor dimensions
  const auto input_shape = input.shape();
  const auto output_shape = output.shape();
  NVTE_CHECK(input_shape.size() >= 2, "Invalid input tensor dimensions (shape=", input_shape, ").");
  NVTE_CHECK(output_shape.size() == input_shape.size(), "Input tensor (shape=", input_shape,
             ") and output tensor (shape=", output_shape, ") do not match.");
  NVTE_CHECK(input_shape[0] == output_shape[1], "Input tensor (shape=", input_shape,
             ") and output tensor (shape=", output_shape, ") do not match.");
  NVTE_CHECK(input_shape[1] == output_shape[0], "Input tensor (shape=", input_shape,
             ") and output tensor (shape=", output_shape, ") do not match.");
  for (size_t i = 2; i < input_shape.size(); ++i) {
    NVTE_CHECK(input_shape[i] == output_shape[i], "Input tensor (shape=", input_shape,
               ") and output tensor (shape=", output_shape, ") do not match.");
  }

  // Reinterpret tensors as 3D tensors of bytes
  const size_t dim0 = output_shape[0];
  const size_t dim1 = output_shape[1];
  size_t dim2 = 1;
  for (size_t i = 2; i < output_shape.size(); ++i) {
    dim2 *= output_shape[i];
  }
  dim2 = get_buffer_size_bytes(dim2, output.dtype());

  // Choose kernel config with performance heuristics
  const size_t sm_count = static_cast<size_t>(cuda::sm_count());
  KernelConfig config(dim0, dim1, dim2, sm_count, 1);
  if (rtc::is_enabled()) {
    auto try_config = [&](size_t vector_size) {
      KernelConfig new_config(dim0, dim1, dim2, sm_count, vector_size);
      if (new_config < config) {
        config = new_config;
      }
    };
    try_config(16);
    try_config(8);
    try_config(4);
    try_config(2);
  }
  const size_t vector_size = config.vector_size;

  // Launch kernel
  if (vector_size == 1) {
    // General kernel
    swap_first_dims_untuned_kernel<<<config.num_blocks, block_size, 0, stream>>>(
        static_cast<const uint8_t *>(input.data.dptr), static_cast<uint8_t *>(output.data.dptr),
        dim0, dim1, dim2);
    NVTE_CHECK_CUDA(cudaGetLastError());
  } else {
    // Compile NVRTC kernel if needed
    auto &rtc_manager = rtc::KernelManager::instance();
    const std::string kernel_label =
        concat_strings("swap_first_dims,vector_size=", vector_size,
                       ",single_load_store=", config.single_load_store);
    if (!rtc_manager.is_compiled(kernel_label)) {
      std::string code = string_code_transpose_rtc_swap_first_dims_cu;
      code = regex_replace(code, "__VECTOR_SIZE__", vector_size);
      code = regex_replace(code, "__BLOCK_SIZE__", block_size);
      code =
          regex_replace(code, "__SINGLE_LOAD_STORE__", static_cast<int>(config.single_load_store));
      rtc_manager.compile(kernel_label, "swap_first_dims_kernel", code,
                          "transformer_engine/common/transpose/rtc/swap_first_dims.cu");
    }
    // Launch NVRTC kernel
    rtc_manager.launch(kernel_label, config.num_blocks, block_size, 0, stream, input.data.dptr,
                       output.data.dptr, dim0, dim1, dim2 / vector_size);
  }
}

}  // namespace transformer_engine

void nvte_swap_first_dims(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
  NVTE_API_CALL(nvte_swap_first_dims);
  using namespace transformer_engine;
  swap_first_dims(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), stream);
}
@@ -29,14 +29,8 @@
namespace transformer_engine {
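// Rounds N up to the nearest multiple of M, e.g. DIVUP_TO_MULTIPLE(300, 128) == 384.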
template <typename T1, typename T2>
__device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(T1 N, T2 M) {
  return DIVUP(static_cast<uint64_t>(N), static_cast<uint64_t>(M)) * M;
}
namespace gated_kernels {
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 512;
@@ -67,30 +61,31 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
                        const float *const scale_ptr, const size_t rows, const size_t cols) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
  const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X;

  const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
  const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X;

  const size_t thread_offset_Y = tid_Y;
  const size_t thread_offset_X = tid_X;

  float amax = 0;
  const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;

  extern __shared__ char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) Does not guarantee the pointer to be aligned!
  uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
                     ~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));

  constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X;
  constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);

  constexpr size_t grad_mem = IS_DGATED ? buff_size_aligned_in : 0;
@@ -99,8 +94,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
  constexpr size_t in_mem = in_act_mem + in_gate_mem;
  constexpr size_t out_act_mem = buff_size_aligned_out;

  constexpr size_t in_transaction_size = buff_elems * sizeof(IType);
  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
@@ -141,12 +134,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll
  for (int it = 0; it < ITERATIONS; ++it) {
    const size_t buff = it % BUFFERS_NUM;
    const size_t next_it = it + 1;
    if (next_it < ITERATIONS) {
      const size_t next_buff = next_it % BUFFERS_NUM;
      const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y;
      const size_t chunk_it_offset_x = chunk_offset_X;
      if constexpr (IS_DGATED) {
        copy_2d_to_sharedx3(
            &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y,
@@ -174,10 +167,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
#pragma unroll
    for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) {
      const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y;
      const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
      const size_t shmem_offset_x = thread_offset_X;
      const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x;

      float act_elt = static_cast<float>(in_act_sh_curr[shmem_idx]);
      float gate_elt = static_cast<float>(in_gate_sh_curr[shmem_idx]);
@@ -220,8 +213,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
    // Initiate TMA transfer to copy shared memory to global memory
    if (is_master_thread) {
      const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y;
      const size_t chunk_it_offset_x = chunk_offset_X;
      // dGeLU
      ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x,
@@ -271,12 +264,36 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
  }
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
#endif // __HIP_PLATFORM_AMD__
namespace mxfp8_kernel {
constexpr size_t CHUNK_DIM_Y = 64;
constexpr size_t CHUNK_DIM_X = 64;
constexpr size_t THREADS_PER_CHUNK_COLWISE = 128;
constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = CHUNK_DIM_X;
constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 32;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t BUFF_DIM_Y = 32;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X;
static_assert(BUFF_DIM_Y == 32);
constexpr size_t PACK_SIZE = 4;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;
// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32
#ifndef __HIP_PLATFORM_AMD__
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
          float (*DActOP)(float, const ParamOP &), typename IType, typename OType,
          bool ROWWISE_SCALING, bool COLWISE_SCALING, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
    cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad,
                            const __grid_constant__ CUtensorMap tensor_map_input_act,
@@ -289,43 +306,73 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
                            const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
                            const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  using IType2 = typename ptx::FPx2<IType>;
  using OType2 = typename ptx::FPx2<OType>;

  constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
  static_assert(STAGES >= 1);

  constexpr bool IS_CACHED_ACT_OP = ROWWISE_SCALING && COLWISE_SCALING;
  constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING);

  // # of rows covered by one wave. Equal to the # of columnwise threads in Y dimension.
  constexpr size_t COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X);

  const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
  const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
  const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
  const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
  const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
  const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;

  constexpr size_t THREADS_X_ROWWISE = CHUNK_DIM_X / SCALE_DIM_X;

  const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
  const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
  const size_t tid_Y_colwise = threadIdx.x / CHUNK_DIM_X;
  const size_t tid_X_colwise = threadIdx.x % CHUNK_DIM_X;

  const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
  const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
  const size_t thread_offset_Y_colwise = tid_Y_colwise;
  const size_t thread_offset_X_colwise = tid_X_colwise;

  const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
  const size_t col_base_rowwise = block_offset_X + thread_offset_X_rowwise;
  const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
  const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;

  const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols);
  const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);

  const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
  const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
  const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
  const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;

  const size_t gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X;
  const size_t gate_scale_idx_offset_colwise = cols;

  // helps resolving bank conflicts in shmem
  const int thread_lane = threadIdx.x % THREADS_PER_WARP;
  const int bank_group = thread_lane / THREADS_PER_BANK;

  constexpr size_t SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1;
  __shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][CHUNK_DIM_X];

  extern __shared__ char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) Does not guarantee the pointer to be aligned!
  uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
                     ~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));

  constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
  constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);

  const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
@@ -334,12 +381,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
  const size_t in_mem = in_act_mem + in_gate_mem;
  const size_t out_act_mem = buff_size_aligned_out;
  const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0);
  const size_t out_mem = out_act_mem + out_gate_mem;

  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  IType *in_grad_sh = reinterpret_cast<IType *>(dshmem);
  IType *in_act_sh = reinterpret_cast<IType *>(dshmem + grad_mem);
@@ -351,375 +395,496 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
  OType *out_act_colwise_sh = out_act_rowwise_sh;
  OType *out_gate_colwise_sh = out_gate_rowwise_sh;
  if constexpr (ROWWISE_SCALING && COLWISE_SCALING) {
    out_act_colwise_sh = reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_mem);
    out_gate_colwise_sh =
        reinterpret_cast<OType *>(dshmem + grad_mem + in_mem + out_mem + out_act_mem);
  }

  IType *cached_act_sh = in_act_sh;    // in_act_sh is used as a cache buffer for activations
  IType *cached_gate_sh = in_gate_sh;  // in_gate_sh is used as a cache buffer for gated values

  constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;

  const bool is_master_thread = (threadIdx.x == 0);

  // Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ alignas(8) uint64_t mbar[STAGES];
  initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);

  int parity = 0;

  if constexpr (IS_DGATED) {
    copy_2d_to_sharedx3(&in_grad_sh[0], &tensor_map_grad, block_offset_X, block_offset_Y,
                        &in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y,
                        &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y,
                        shmem_buff_size, &mbar[0], is_master_thread);
  } else {
    copy_2d_to_sharedx2(&in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y,
                        &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y,
                        shmem_buff_size, &mbar[0], is_master_thread);
  }

#pragma unroll
  for (int stage = 0; stage < STAGES; ++stage) {
    const size_t buff = stage % BUFFS_NUM;
    const size_t next_stage = stage + 1;
    const size_t stage_offset_Y = stage * BUFF_DIM_Y;

    if (next_stage < STAGES) {
      // Wait for TMA transfer to have finished reading shared memory.
      // I.e. the buffer is ready to be written to
      ptx::cp_async_bulk_wait_group_read<1>();

      const size_t next_buff = next_stage % BUFFS_NUM;
      const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
      const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
      const size_t global_offset_X = block_offset_X;
      const size_t next_buff_offset = next_buff * BUFF_DIM;
      if constexpr (IS_DGATED) {
        copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X,
                            global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act,
                            global_offset_X, global_offset_Y, &in_gate_sh[next_buff_offset],
                            &tensor_map_input_gate, global_offset_X, global_offset_Y,
                            shmem_buff_size, &mbar[next_stage], is_master_thread);
      } else {
        copy_2d_to_sharedx2(&in_act_sh[next_buff_offset], &tensor_map_input_act, global_offset_X,
                            global_offset_Y, &in_gate_sh[next_buff_offset], &tensor_map_input_gate,
                            global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage],
                            is_master_thread);
      }
    }

    ptx::fence_proxy_async_shared_cta();

    // Wait for the data to have arrived
    ptx::mbarrier_wait_parity(&mbar[stage], parity);

    if constexpr (COLWISE_SCALING) {
      const size_t shmem_offset_base_colwise =
          buff * BUFF_DIM + tid_Y_colwise * BUFF_DIM_X + tid_X_colwise;
      float thread_amax_act = 0.0f;
      float thread_amax_gate = 0.0f;
      float after_act_colwise[BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE];
      float after_gate_colwise[BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE];

      // 1. Read/Compute elements. Find MXFP8-block AMAX
#pragma unroll
      for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
        const size_t shmem_offset_colwise =
            shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;

        float act_elt = static_cast<float>(in_act_sh[shmem_offset_colwise]);
        float gate_elt = static_cast<float>(in_gate_sh[shmem_offset_colwise]);
        float after_act_elt;
        float after_gate_elt;

        if constexpr (IS_DGATED) {
          float grad_elt = static_cast<float>(in_grad_sh[shmem_offset_colwise]);
          const float x = act_elt;
          float act_x;
          float dact_x;
          if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
            const float s = sigmoidf(x);
            act_x = x * s;
            dact_x = x * s * (1 - s) + s;
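            // dSiLU: d/dx [x * sigmoid(x)] = sigmoid(x) + x * sigmoid(x) * (1 - sigmoid(x))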
          } else {
            act_x = ActOP(x, {});
            dact_x = DActOP(x, {});
          }
          after_act_elt = dact_x * grad_elt * gate_elt;
          after_gate_elt = act_x * grad_elt;
        } else {
          after_act_elt = ActOP(act_elt, {}) * gate_elt;
        }
        // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
        if constexpr (!std::is_same_v<IType, float>) {
          after_act_elt = static_cast<float>(static_cast<IType>(after_act_elt));
          if constexpr (IS_DGATED) {
            after_gate_elt = static_cast<float>(static_cast<IType>(after_gate_elt));
          }
        }

        after_act_colwise[i] = after_act_elt;
        if constexpr (IS_DGATED) {
          after_gate_colwise[i] = after_gate_elt;
        }

        // Cache computed activations to avoid recomputing them in the 2nd pass along the other dimension
        if constexpr (IS_CACHED_ACT_OP) {
          cached_act_sh[shmem_offset_colwise] = static_cast<IType>(after_act_elt);
          if constexpr (IS_DGATED) {
            cached_gate_sh[shmem_offset_colwise] = static_cast<IType>(after_gate_elt);
          }
        }

        const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
        const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
        if (!out_of_bounds) {
          thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt));
          if constexpr (IS_DGATED) {
            thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt));
          }
        }
      }

      if constexpr (ONLY_COLWISE_SCALING) {
        // Threads whose id along the Y-dim is 0 don't need to store to shared memory,
        // as they manage the columnwise reduction of the amax
        if (tid_Y_colwise > 0) {
          subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_act;
        }
        __syncthreads();
        if (tid_Y_colwise == 0) {
#pragma unroll
          for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) {
            const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise];
            __builtin_assume(thread_amax_act >= 0);
            __builtin_assume(other_thread_amax >= 0);
            thread_amax_act = fmaxf(thread_amax_act, other_thread_amax);
          }
          subamax_colwise_buff[0][tid_X_colwise] = thread_amax_act;
        }
        __syncthreads();
        // All threads read the reduced amax (ACT)
        thread_amax_act = subamax_colwise_buff[0][tid_X_colwise];

        if constexpr (IS_DGATED) {
          // Make sure the previous read of the ACT values has been completed,
          // so the data are not overwritten
          __syncthreads();
          if (tid_Y_colwise > 0) {
            subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_gate;
          }
          __syncthreads();
          if (tid_Y_colwise == 0) {
#pragma unroll
            for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) {
              const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise];
              __builtin_assume(thread_amax_gate >= 0);
              __builtin_assume(other_thread_amax >= 0);
              thread_amax_gate = fmaxf(thread_amax_gate, other_thread_amax);
            }
            subamax_colwise_buff[0][tid_X_colwise] = thread_amax_gate;
          }
          __syncthreads();
          // All threads read the reduced amax (GATE)
          thread_amax_gate = subamax_colwise_buff[0][tid_X_colwise];
        }
      }

      // 2. Compute E8M0 scaling factor
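      // The MXFP8 block scale is a power of two stored as an E8M0 biased exponent,
      // chosen from the block amax so the largest element maps into the FP8 max-normal range.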
      const e8m0_t biased_exponent_act =
          ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
      const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
      const size_t global_scales_offset_X = scales_offset_X_colwise;
      const size_t scale_idx =
          global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
      const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows;
      const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise;
      if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
        scales_colwise[scale_idx] = biased_exponent_act;
      }
      float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act);
      float block_scale_inverse_gate;

      if constexpr (IS_DGATED) {
        const e8m0_t biased_exponent_gate =
            ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
        // const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2;
        const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise;
        if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) {
          scales_colwise[scale_idx_gate] = biased_exponent_gate;
        }
        block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate);
      }

      // 3. Scale elements
#pragma unroll
      for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) {
        const size_t shmem_offset_elt =
            shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X;
        if constexpr (IS_DGATED) {
          OType2 out_pair;
          ptx::floatx2 in_pair = {after_act_colwise[i], after_gate_colwise[i]};
          const ptx::floatx2 block_scale_inverse_2x_pair = {block_scale_inverse_act,
                                                            block_scale_inverse_gate};
          ptx::mul_cvt_2x(out_pair, in_pair, block_scale_inverse_2x_pair);
          out_act_colwise_sh[shmem_offset_elt] = out_pair.x;
          out_gate_colwise_sh[shmem_offset_elt] = out_pair.y;
        } else {
          const float scaled_out_act = block_scale_inverse_act * after_act_colwise[i];
          out_act_colwise_sh[shmem_offset_elt] = static_cast<OType>(scaled_out_act);
        }
      }
    }

    if constexpr (ROWWISE_SCALING) {
      const size_t shmem_offset_base_rowwise =
          buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;

      float thread_amax_act = 0.0f;
      float thread_amax_gate = 0.0f;

      Vec<IType, PACK_SIZE> in_cached_act[WAVES];
      Vec<IType, PACK_SIZE> in_cached_gate[WAVES];

      float after_act_rowwise[SCALE_DIM_X];
      float after_gate_rowwise[SCALE_DIM_X];

      // 1. Read/Compute elements. Find MXFP8-block AMAX
      if constexpr (IS_CACHED_ACT_OP) {
        // ensures that all writes to cache made in the section above are visible to all threads
        __syncthreads();
        IType2 thread_amax_2x_act = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
        IType2 thread_amax_2x_gate = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
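          // Rotating each thread's starting wave by its bank group spreads the
          // concurrent shared-memory accesses of a warp across different banks.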
          const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
          const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
          const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;

          const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
          const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
          const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);

          // Load cached elements
          in_cached_act[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
          if constexpr (IS_DGATED) {
            in_cached_gate[w].load_from(&cached_gate_sh[shmem_offset_rowwise]);
          }
          // Since the TMA requirement for data alignment is 16B (i.e. cols % 8 == 0 for BF16 elements),
          // a single check (w.r.t. the column direction) suffices to ensure the entire wave is inside the boundaries
          if (!out_of_bounds) {
            if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
              for (int e = 0; e < PACK_SIZE; ++e) {
                thread_amax_act = fmaxf(thread_amax_act, fabsf(in_cached_act[w].data.elt[e]));
                if constexpr (IS_DGATED) {
                  thread_amax_gate = fmaxf(thread_amax_gate, fabsf(in_cached_gate[w].data.elt[e]));
                }
              }
            } else {
#pragma unroll
              for (int e = 0; e < PACK_SIZE; e += 2) {
                const IType2 in_cached_2x_act = {in_cached_act[w].data.elt[e],
                                                 in_cached_act[w].data.elt[e + 1]};
                ptx::abs_max_2x(thread_amax_2x_act, thread_amax_2x_act, in_cached_2x_act);
                if constexpr (IS_DGATED) {
                  const IType2 in_cached_2x_gate = {in_cached_gate[w].data.elt[e],
                                                    in_cached_gate[w].data.elt[e + 1]};
                  ptx::abs_max_2x(thread_amax_2x_gate, thread_amax_2x_gate, in_cached_2x_gate);
                }
              }
            }
          }
        }
        if constexpr (!std::is_same_v<IType, float>) {
          thread_amax_act = static_cast<float>(
              __hmax(__habs(thread_amax_2x_act.x), __habs(thread_amax_2x_act.y)));
          if constexpr (IS_DGATED) {
            thread_amax_gate = static_cast<float>(
                __hmax(__habs(thread_amax_2x_gate.x), __habs(thread_amax_2x_gate.y)));
          }
        }
      } else {
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
          const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
          const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;

          Vec<IType, PACK_SIZE> in_grad;
          Vec<IType, PACK_SIZE> in_act;
          Vec<IType, PACK_SIZE> in_gate;

          in_act.load_from(&in_act_sh[shmem_offset_rowwise]);
          in_gate.load_from(&in_gate_sh[shmem_offset_rowwise]);
          if constexpr (IS_DGATED) {
            in_grad.load_from(&in_grad_sh[shmem_offset_rowwise]);
          }

#pragma unroll
          for (int e = 0; e < PACK_SIZE; ++e) {
            const int j = w * PACK_SIZE + e;
            float act_elt = static_cast<float>(in_act.data.elt[e]);
            float gate_elt = static_cast<float>(in_gate.data.elt[e]);
            float after_act_elt;
            float after_gate_elt;

            if constexpr (IS_DGATED) {
              float grad_elt = static_cast<float>(in_grad.data.elt[e]);
              const float x = act_elt;
              float act_x;
              float dact_x;
              if constexpr ((ActOP == &silu<fp32, fp32>) && (DActOP == &dsilu<fp32, fp32>)) {
                const float s = sigmoidf(x);
                act_x = x * s;
                dact_x = x * s * (1 - s) + s;
              } else {
                act_x = ActOP(x, {});
                dact_x = DActOP(x, {});
              }
              after_act_elt = dact_x * grad_elt * gate_elt;
              after_gate_elt = act_x * grad_elt;
              after_act_rowwise[j] = after_act_elt;
              after_gate_rowwise[j] = after_gate_elt;
            } else {
              after_act_elt = ActOP(act_elt, {}) * gate_elt;
              after_act_rowwise[j] = after_act_elt;
            }
            // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
            if constexpr (!std::is_same_v<IType, float>) {
              after_act_elt = static_cast<float>(static_cast<IType>(after_act_elt));
              if constexpr (IS_DGATED) {
                after_gate_elt = static_cast<float>(static_cast<IType>(after_gate_elt));
              }
            }

            const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
            const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
            const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
            if (!out_of_bounds) {
              thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt));
              if constexpr (IS_DGATED) {
                thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt));
              }
            }
          }
        }
      }

      // 2. Compute E8M0 scaling factor
      const e8m0_t biased_exponent_act =
          ptx::float_to_e8m0(thread_amax_act * Quantized_Limits<OType>::max_norm_rcp);
      const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
      const size_t stage_scales_offset_X = scales_offset_X_rowwise;
      const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
      const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows;
      const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise;
      if (!out_of_bounds_rowwise) {
        scales_rowwise[scale_idx] = biased_exponent_act;
      }
      const float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act);
      const ptx::floatx2 block_scale_inverse_2x_act = {block_scale_inverse_act,
                                                       block_scale_inverse_act};

      float block_scale_inverse_gate;
      ptx::floatx2 block_scale_inverse_2x_gate;
      if constexpr (IS_DGATED) {
        const e8m0_t biased_exponent_gate =
            ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits<OType>::max_norm_rcp);
        const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise;
        if (!out_of_bounds_rowwise) {
          scales_rowwise[scale_idx_gate] = biased_exponent_gate;
        }
        block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate);
        block_scale_inverse_2x_gate = {block_scale_inverse_gate, block_scale_inverse_gate};
      }

      // 3. Scale elements
#pragma unroll
      for (int w = 0; w < WAVES; ++w) {
        Vec<OType2, PACK_SIZE / 2> out_act;
        Vec<OType2, PACK_SIZE / 2> out_gate;
#pragma unroll
        for (int e = 0; e < PACK_SIZE / 2; ++e) {
          IType2 in_act;
          OType2 &out_act_pair = reinterpret_cast<OType2 &>(out_act.data.elt[e]);

          if constexpr (IS_CACHED_ACT_OP) {
            in_act.x = in_cached_act[w].data.elt[2 * e];
            in_act.y = in_cached_act[w].data.elt[2 * e + 1];
          } else {
            const int j = w * PACK_SIZE + 2 * e;
            in_act.x = after_act_rowwise[j];
            in_act.y = after_act_rowwise[j + 1];
          }
          ptx::mul_cvt_2x(out_act_pair, in_act, block_scale_inverse_2x_act);

          if constexpr (IS_DGATED) {
            IType2 in_gate;
            OType2 &out_gate_pair = reinterpret_cast<OType2 &>(out_gate.data.elt[e]);
            if constexpr (IS_CACHED_ACT_OP) {
              in_gate.x = in_cached_gate[w].data.elt[2 * e];
              in_gate.y = in_cached_gate[w].data.elt[2 * e + 1];
            } else {
              const int j = w * PACK_SIZE + 2 * e;
              in_gate.x = after_gate_rowwise[j];
              in_gate.y = after_gate_rowwise[j + 1];
            }
            ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate);
          }
        }
        const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
        const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
        const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
        out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]);
        if constexpr (IS_DGATED) {
          out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]);
        }
      }
    }

    // Wait for shared memory writes to be visible to TMA engine.
    ptx::fence_proxy_async_shared_cta();
    __syncthreads();
    // After syncthreads, writes by all threads are visible to TMA engine.

    // Initiate TMA transfer to copy shared memory to global memory
    if (is_master_thread) {
      const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
      const size_t global_offset_X = block_offset_X;
      const size_t buff_offset = buff * BUFF_DIM;

      if constexpr (ROWWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_act_rowwise), global_offset_X,
            global_offset_Y, reinterpret_cast<uint64_t *>(&out_act_rowwise_sh[buff_offset]));
        if constexpr (IS_DGATED) {
          ptx::cp_async_bulk_tensor_2d_shared_to_global(
              reinterpret_cast<const uint64_t *>(&tensor_map_output_gate_rowwise), global_offset_X,
              global_offset_Y, reinterpret_cast<uint64_t *>(&out_gate_rowwise_sh[buff_offset]));
        }
      }
      if constexpr (COLWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_act_colwise), global_offset_X,
            global_offset_Y, reinterpret_cast<uint64_t *>(&out_act_colwise_sh[buff_offset]));
        if constexpr (IS_DGATED) {
          ptx::cp_async_bulk_tensor_2d_shared_to_global(
              reinterpret_cast<const uint64_t *>(&tensor_map_output_gate_colwise), global_offset_X,
              global_offset_Y, reinterpret_cast<uint64_t *>(&out_gate_colwise_sh[buff_offset]));
        }
      }

      // Create a "bulk async-group" out of the previous bulk copy operation.
      ptx::cp_async_bulk_commit_group();
    }
  }

  ptx::cp_async_bulk_wait_group_read<0>();
  __syncthreads();

  parity ^= 1;
  destroy_barriers<STAGES>(mbar, is_master_thread);
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
#endif  // __HIP_PLATFORM_AMD__
} // namespace mxfp8_kernel
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
          float (*DActOP)(float, const ParamOP &)>
@@ -728,6 +893,8 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
#ifdef __HIP_PLATFORM_AMD__
  assert(false);
#else
  checkCuDriverContext(stream);
  if (output->has_data()) {
    NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
  }
@@ -780,17 +947,16 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
  const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X;
  const size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  const size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
  const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
  const size_t in_act_mem = buff_size_aligned_in;
  const size_t in_gate_mem = buff_size_aligned_in;
  const size_t out_act_mem = buff_size_aligned_out;
  const size_t out_gate_mem = buff_size_aligned_out;

  const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) +
                            (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT;
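  // The extra TMA_SHMEM_ALIGNMENT bytes leave room for the kernel to round its dynamic
  // shared-memory base pointer up to the TMA alignment at runtime; an __align__ attribute
  // on the extern buffer alone does not guarantee that alignment (see the dshmem
  // adjustment in the kernel body).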
  cudaFuncSetAttribute(
      cast_fp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType>,
@@ -811,7 +977,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
                      cudaStream_t stream) {
#ifdef __HIP_PLATFORM_AMD__
  assert(false);
#else
  checkCuDriverContext(stream);
  const bool USE_ROWWISE_SCALING = output->has_data();
  const bool USE_COLWISE_SCALING = output->has_columnwise_data();
@@ -822,16 +990,34 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
    NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated.");
  }

  ScalingType scaling_type;
  if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) {
    scaling_type = ScalingType::ROWWISE;
  } else if ((!USE_ROWWISE_SCALING) && USE_COLWISE_SCALING) {
    scaling_type = ScalingType::COLWISE;
  } else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) {
    scaling_type = ScalingType::BIDIMENSIONAL;
  }
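  // NOTE: at least one of the rowwise/columnwise outputs is expected to be allocated;
  // if neither is present, scaling_type above is left uninitialized.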
  const size_t rows = gated_input.flat_first_dim();
  const size_t cols = gated_input.flat_last_dim() / 2;
  const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;

  constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y;
  constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X;
  constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM;

  const size_t blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y);
  const size_t blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X);

  constexpr size_t THREADS_PER_CHUNK_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_COLWISE;
  constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_NON_COLWISE;
  const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE)
                                       ? THREADS_PER_CHUNK_COLWISE
                                       : THREADS_PER_CHUNK_NON_COLWISE;

  const dim3 grid(blocks_X, blocks_Y);
  const dim3 block_size(THREADS_PER_CHUNK);
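  // Each CUDA block processes one CHUNK_DIM_Y x CHUNK_DIM_X tile of the gated input.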
  size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1;
  size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1;
@@ -841,94 +1027,122 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
  e8m0_t *const scales_colwise_ptr =
      USE_COLWISE_SCALING ? reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr) : nullptr;

  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
      gated_input.dtype(), IType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
          output->dtype(), OType,

          alignas(64) CUtensorMap tensor_map_grad{};
          alignas(64) CUtensorMap tensor_map_input_act{};
          alignas(64) CUtensorMap tensor_map_input_gate{};
          alignas(64) CUtensorMap tensor_map_output_act_rowwise{};
          alignas(64) CUtensorMap tensor_map_output_gate_rowwise{};
          alignas(64) CUtensorMap tensor_map_output_act_colwise{};
          alignas(64) CUtensorMap tensor_map_output_gate_colwise{};

          constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
          constexpr size_t output_type_bit_size = TypeInfo<OType>::size;

          if constexpr (IS_DGATED) {
            create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
                                 cols, 0, input_type_bit_size);
          }

          const uint32_t tensor_stride_elems = output_cols;
          create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, BUFF_DIM_Y,
                               BUFF_DIM_X, cols * 2, 0, input_type_bit_size);
          create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_Y,
                               BUFF_DIM_X, cols * 2, cols, input_type_bit_size);

          if (USE_ROWWISE_SCALING) {
            create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols,
                                 BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
                                 output_type_bit_size);
            create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols,
                                 BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
                                 output_type_bit_size);
          }

          if (USE_COLWISE_SCALING) {
            create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, cols,
                                 BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0,
                                 output_type_bit_size);
            create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows,
                                 cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols,
                                 output_type_bit_size);
          }

          const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X;
          const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
          const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
          const size_t buff_size_aligned_in =
              DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
          const size_t buff_size_aligned_out =
              DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);

          const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0);
          const size_t in_act_mem = buff_size_aligned_in;
          const size_t in_gate_mem = buff_size_aligned_in;
          const size_t in_mem = grad_mem + in_act_mem + in_gate_mem;

          const size_t out_act_mem = buff_size_aligned_out;
          const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0);
          size_t out_mem = out_act_mem + out_gate_mem;
          if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; }

          const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;

          switch (scaling_type) {
            case ScalingType::ROWWISE:
              cudaFuncSetAttribute(
                  mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
                                                        OType, true, false,
                                                        THREADS_PER_CHUNK_NON_COLWISE>,
                  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);

              mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
                                                    true, false, THREADS_PER_CHUNK_NON_COLWISE>
                  <<<grid, block_size, shmem_size, stream>>>(
                      tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
                      tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
                      tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
                      scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
                      scale_stride_colwise);
              break;
            case ScalingType::COLWISE:
              cudaFuncSetAttribute(
                  mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
                                                        OType, false, true,
                                                        THREADS_PER_CHUNK_COLWISE>,
                  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);

              mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
                                                    false, true, THREADS_PER_CHUNK_COLWISE>
                  <<<grid, block_size, shmem_size, stream>>>(
                      tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
                      tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
                      tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
                      scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
                      scale_stride_colwise);
              break;
            case ScalingType::BIDIMENSIONAL:
              cudaFuncSetAttribute(
                  mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType,
                                                        OType, true, true,
                                                        THREADS_PER_CHUNK_NON_COLWISE>,
                  cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size);

              mxfp8_kernel::cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
                                                    true, true, THREADS_PER_CHUNK_NON_COLWISE>
                  <<<grid, block_size, shmem_size, stream>>>(
                      tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
                      tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
                      tensor_map_output_act_colwise, tensor_map_output_gate_colwise,
                      scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise,
                      scale_stride_colwise);
              break;
          });  // NOLINT(*)
  );  // NOLINT(*)
#endif
}
@@ -1009,7 +1223,6 @@ template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP
                    float (*DActOP)(float, const ParamOP &)>
void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output,
                    cudaStream_t stream) {
  constexpr bool allow_empty = false;
  CheckInputTensor(gated_input, "gated_input");
  CheckOutputTensor(*output, "output", allow_empty);
...
@@ -30,37 +30,25 @@
namespace transformer_engine {
namespace mxfp8_kernel {

constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 32;

constexpr size_t BUFFS_NUM = 2;
constexpr size_t PACK_SIZE = 4;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;

// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1;  // 128

// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X;  // 4 = 128 / 32
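// Example of the swizzled access pattern used below: with PACK_SIZE = 4 and
// SCALE_DIM_X = 32 there are WAVES = 8 four-element groups per scaling block. Threads
// whose lanes differ by THREADS_PER_BANK would otherwise hit the same bank in every
// wave; rotating the starting group by bank_group = lane / THREADS_PER_BANK spreads
// those accesses across different banks.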
#ifndef __HIP_PLATFORM_AMD__
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
          float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING,
          bool COLWISE_SCALING, size_t CHUNK_DIM_Y, size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK>
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
    cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input,
                         const __grid_constant__ CUtensorMap tensor_map_act_input,
                         const __grid_constant__ CUtensorMap tensor_map_output_rowwise,
@@ -70,201 +58,348 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
                         const size_t rows, const size_t cols, const size_t scale_stride_rowwise,
                         const size_t scale_stride_colwise) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT;
  constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS;

  using IType2 = typename ptx::FPx2<IType>;
  using OType2 = typename ptx::FPx2<OType>;

  if constexpr (NO_ACTIVATIONS) {
    if (noop != nullptr && noop[0] == 1.0f) {
      return;
    }
  }

  constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X;
  constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X;

  constexpr size_t BUFF_DIM_Y = THREADS_Y;
  constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
  constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X;
  static_assert(BUFF_DIM_Y == 32);

  constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
  static_assert(STAGES >= 1);

  constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;

  const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
  const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X;
  const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y;
  const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X;
  const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y;
  const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X;

  const size_t tid_Y_rowwise = threadIdx.x / THREADS_X;
  const size_t tid_X_rowwise = threadIdx.x % THREADS_X;
  const size_t tid_Y_colwise = 0;
  const size_t tid_X_colwise = threadIdx.x;

  const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
  const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
  const size_t thread_offset_Y_colwise = tid_Y_colwise;
  const size_t thread_offset_X_colwise = tid_X_colwise;

  const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise;
  const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise;
  const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise;

  const bool col_out_of_bounds_colwise = (col_base_colwise >= cols);

  const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
  const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
  const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
  const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;

  // helps resolving bank conflicts in shmem
  const int thread_lane = threadIdx.x % THREADS_PER_WARP;
  const int bank_group = thread_lane / THREADS_PER_BANK;

  constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
  constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
  constexpr size_t buff_size_aligned_in =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
  constexpr size_t buff_size_aligned_out =
      DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);

  constexpr size_t elt_input_mem = buff_size_aligned_in;
  constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
  constexpr size_t in_mem = elt_input_mem + act_input_mem;

  constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0);

  extern __shared__ char dynamic_shmem[];
  uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
  // Manually align dynamic SHMEM per TMA requirements using padding
  // __align__(128) Does not guarantee the pointer to be aligned!
  uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) &
                     ~(static_cast<uintptr_t>(TMA_SHMEM_ALIGNMENT - 1));
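  // The expression above computes (p + A - 1) & ~(A - 1), which for a power-of-two A
  // rounds p up to the next multiple of A (e.g. 0x28 with A = 128 becomes 0x80).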
  // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
  IType *in_sh = reinterpret_cast<IType *>(dshmem);
  IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem);
  OType *out_rowwise_sh = reinterpret_cast<OType *>(dshmem + in_mem);
  OType *out_colwise_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
  IType *cached_act_sh = in_sh;  // in_sh is used as a cache buffer

  constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
  const bool is_master_thread = (threadIdx.x == 0);

  float partial_dbias_colwise = 0.0f;
  float thread_dbias_rowwise[SCALE_DIM_X];
  if constexpr (IS_DBIAS) {
#pragma unroll
    for (int j = 0; j < SCALE_DIM_X; ++j) {
      thread_dbias_rowwise[j] = 0.0f;
    }
  }

  float block_amax = 0.0f;

// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ alignas(8) uint64_t mbar[STAGES];

  initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, is_master_thread);

  int parity = 0;

  if constexpr (IS_DACT) {
    copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0],
                        &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size,
                        &mbar[0], is_master_thread);
  } else {
    copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
                      &mbar[0], is_master_thread);
  }
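  // Double-buffered pipeline: stage s computes on shared buffer s % BUFFS_NUM while the
  // TMA load for stage s + 1 fills the other buffer.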
#pragma unroll
  for (int stage = 0; stage < STAGES; ++stage) {
    const size_t buff = stage % BUFFS_NUM;
    const size_t next_stage = stage + 1;
    const size_t stage_offset_Y = stage * BUFF_DIM_Y;

    if (next_stage < STAGES) {
      // Wait for TMA transfer to have finished reading shared memory.
      // I.e. the buffer is ready to be written to
      ptx::cp_async_bulk_wait_group_read<1>();

      const size_t next_buff = next_stage % BUFFS_NUM;
      const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
      const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
      const size_t global_offset_X = block_offset_X;
      const size_t next_buff_offset = next_buff * BUFF_DIM;
      if constexpr (IS_DACT) {
        copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
                            global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input,
                            global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage],
                            is_master_thread);
      } else {
        copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
                          global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread);
      }
    }

    ptx::fence_proxy_async_shared_cta();

    // Wait for the data to have arrived
    ptx::mbarrier_wait_parity(&mbar[stage], parity);

    // Trigger the next kernel, so its TMA load can be overlapped with the current kernel
    if (stage == STAGES - 1) {
      cudaTriggerProgrammaticLaunchCompletion();
    }
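    // Note: the trigger above should only have an effect when this kernel is launched
    // with programmatic dependent launch (PDL) attributes; for an ordinary launch it is
    // expected to be a harmless no-op.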
    float thread_amax = 0.0f;

    if constexpr (COLWISE_SCALING) {
      const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
      thread_amax = 0.0f;
      float in_compute_colwise[BUFF_DIM_Y];
      IType in_colwise_IType[BUFF_DIM_Y];

      // 1. Read/Compute elements. Find MXFP8-block AMAX
      if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
        IType thread_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
        for (int i = 0; i < BUFF_DIM_Y; ++i) {
          const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
          in_colwise_IType[i] = in_sh[shmem_offset_colwise];
          thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i]));
        }
        thread_amax = static_cast<float>(thread_amax_f16);
      } else {
#pragma unroll
        for (int i = 0; i < BUFF_DIM_Y; ++i) {
          const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
          float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
          if constexpr (IS_ACT) {
            elt = OP(elt, {});
          }
          if constexpr (IS_DACT) {
            float act_in_elt = static_cast<float>(act_in_sh[shmem_offset_colwise]);
            elt *= OP(act_in_elt, {});
          }
          if constexpr (IS_DBIAS) {
            partial_dbias_colwise += elt;
          }
          // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
          if constexpr (!std::is_same_v<IType, float>) {
            elt = static_cast<float>(static_cast<IType>(elt));
          }
          // Cache computed activations to avoid computing them again in the 2nd pass along another dimension
          if constexpr (IS_CACHED_ACT_OP) {
            cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
          }
          if constexpr (COMPUTE_ACTIVATIONS) {
            const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows);
            const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise);
            if (!out_of_bounds) {
              thread_amax = fmaxf(thread_amax, fabsf(elt));
            }
          } else {
            // If no activation, elt is 0 so we can safely do this
            thread_amax = fmaxf(thread_amax, fabsf(elt));
          }
          in_compute_colwise[i] = elt;
        }
      }

      // 2. Compute E8M0 scaling factor
      const e8m0_t biased_exponent =
          ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);

      const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
      const size_t global_scales_offset_X = scales_offset_X_colwise;
      const size_t scale_idx =
          global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
      scales_colwise[scale_idx] = biased_exponent;

      const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
      const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
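      // Worked example (assuming float_to_e8m0 rounds the amax ratio up to a power of
      // two): with an E4M3 output, max_norm_rcp = 1/448, so a block amax of 448 gives
      // 448/448 = 2^0, i.e. a stored biased exponent of 127 and block_scale_inverse = 1.0f.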
      // 3. Scale elements
#pragma unroll
      for (int i = 0; i < SCALE_DIM_Y; ++i) {
        float in;
        if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
          in = static_cast<float>(in_colwise_IType[i]);
        } else {
          in = in_compute_colwise[i];
        }
        const float scaled_out = in * block_scale_inverse;

        const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
        out_colwise_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
      }
    }

    if constexpr (ROWWISE_SCALING) {
      const size_t shmem_offset_base_rowwise =
          buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
      thread_amax = 0.0f;
      float in_compute_rowwise[SCALE_DIM_X];
      Vec<IType, PACK_SIZE> in_cached[WAVES];

      // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY
      Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];

      // 1. Read/Compute elements. Find MXFP8-block AMAX
      if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
        IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
          const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
          const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
          // Load elements
          in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
          for (int e = 0; e < PACK_SIZE / 2; ++e) {
            ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
          }
        }
        thread_amax =
            static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
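        // abs_max_2x keeps a running elementwise |max| over packed pairs; the final
        // __habs/__hmax collapses the two partial maxima into a scalar amax.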
      } else if constexpr (IS_CACHED_ACT_OP) {
        // ensures that all writes to cache made in the section above are visible to all threads
        __syncthreads();
        IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
          const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
          const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;

          const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
          const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols);
          const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);

          // Load cached elements
          in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
          // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements)
          // only single check (w.r.t. column direction) is sufficient to be sure the entire wave is inside the boundaries
          if (!out_of_bounds) {
            if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
              for (int e = 0; e < PACK_SIZE; ++e) {
                thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e]));
              }
            } else {
#pragma unroll
              for (int e = 0; e < PACK_SIZE; e += 2) {
                const IType2 in_cached_2x = {in_cached[w].data.elt[e],
                                             in_cached[w].data.elt[e + 1]};
                ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
              }
            }
          }
        }
        if constexpr (!std::is_same_v<IType, float>) {
          thread_amax =
              static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
        }
      } else {
#pragma unroll
        for (int w = 0; w < WAVES; ++w) {
          const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
          const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
          const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;

          Vec<IType, PACK_SIZE> in;
          Vec<IType, PACK_SIZE> act_in;

          in.load_from(&in_sh[shmem_offset_rowwise]);
          if constexpr (IS_DACT) {
            act_in.load_from(&act_in_sh[shmem_offset_rowwise]);
          }
#pragma unroll
          for (int e = 0; e < PACK_SIZE; ++e) {
            const int j = w * PACK_SIZE + e;
            // Compute element
            float elt = static_cast<float>(in.data.elt[e]);
            if constexpr (IS_ACT) {
              elt = OP(elt, {});
            }
            if constexpr (IS_DACT) {
              float act_in_elt = static_cast<float>(act_in.data.elt[e]);
              elt *= OP(act_in_elt, {});
            }

            // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again
            if constexpr (IS_DBIAS && (!COLWISE_SCALING)) {
              thread_dbias_rowwise[j] += elt;
            }
            // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
            if constexpr (!std::is_same_v<IType, float>) {
              elt = static_cast<float>(static_cast<IType>(elt));
            }
            if constexpr (COMPUTE_ACTIVATIONS) {
              const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows);
              const bool swizzled_col_out_of_bounds =
                  (block_offset_X + swizzled_thread_idx >= cols);
              const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds);
              if (!out_of_bounds) {
                thread_amax = fmaxf(thread_amax, fabsf(elt));
              }
@@ -272,197 +407,142 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
            } else {
              // If no activation, elt is 0 so we can safely do this
              thread_amax = fmaxf(thread_amax, fabsf(elt));
            }
            in_compute_rowwise[j] = elt;
          }
        }
      }

      // 2. Compute E8M0 scaling factor
      const e8m0_t biased_exponent =
          ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
      const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
      const size_t stage_scales_offset_X = scales_offset_X_rowwise;
      const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
      scales_rowwise[scale_idx] = biased_exponent;

      const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
      const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};

      // 3. Scale elements
#pragma unroll
      for (int w = 0; w < WAVES; ++w) {
        Vec<OType2, PACK_SIZE / 2> out;
#pragma unroll
        for (int e = 0; e < PACK_SIZE / 2; ++e) {
          IType2 in;
          OType2 &out_pair = reinterpret_cast<OType2 &>(out.data.elt[e]);
          if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
            in = in_IType[w].data.elt[e];
          } else if constexpr (IS_CACHED_ACT_OP) {
            in.x = in_cached[w].data.elt[2 * e];
            in.y = in_cached[w].data.elt[2 * e + 1];
          } else {
            const int j = w * PACK_SIZE + 2 * e;
            in.x = in_compute_rowwise[j];
            in.y = in_compute_rowwise[j + 1];
          }
          ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x);
        }
        const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
        const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
        const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
        out.store_to(&out_rowwise_sh[shmem_offset_rowwise]);
      }
    }
    __builtin_assume(block_amax >= 0);
    __builtin_assume(thread_amax >= 0);
    block_amax = fmaxf(block_amax, thread_amax);

    // Wait for shared memory writes to be visible to TMA engine.
    ptx::fence_proxy_async_shared_cta();
    __syncthreads();
    // After syncthreads, writes by all threads are visible to TMA engine.

    // Initiate TMA transfer to copy shared memory to global memory
    if (is_master_thread) {
      const size_t global_offset_Y = block_offset_Y + stage_offset_Y;
      const size_t global_offset_X = block_offset_X;
      const size_t buff_offset = buff * BUFF_DIM;

      if constexpr (ROWWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
            global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_sh[buff_offset]));
      }
      if constexpr (COLWISE_SCALING) {
        ptx::cp_async_bulk_tensor_2d_shared_to_global(
            reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
            global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_sh[buff_offset]));
      }

      // Create a "bulk async-group" out of the previous bulk copy operation.
      ptx::cp_async_bulk_commit_group();
    }
  }

  parity ^= 1;
  if constexpr (IS_DBIAS) {
    float thread_partial_dbias = 0.0f;

    if constexpr (COLWISE_SCALING) {
      thread_partial_dbias = partial_dbias_colwise;
    } else {
      // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH]
      // HEIGHT = THREADS_Y
      // WIDTH = THREADS_X * (SCALE_DIM_X + 1)

      // Added extra 1-element padding per thread_X to reduce bank conflicts
      float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);

      constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);

      const size_t shmem_thread_offset =
          tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
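      // With SCALE_DIM_X = 32, each thread's row segment is 33 floats (132 bytes), so
      // successive thread_X segments start on different shared memory banks.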
#pragma unroll
      for (int w = 0; w < WAVES; ++w) {
        const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
        const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
#pragma unroll
        for (int e = 0; e < PACK_SIZE; ++e) {
          const int j = w * PACK_SIZE + e;
          const size_t shmem_elt_idx = swizzled_group_offset + e;
          partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
        }
      }
      __syncthreads();
#pragma unroll
      for (int i = 0; i < THREADS_Y; ++i) {
        // Add extra element offset per MXFP8 scaling block [1x32]
        const size_t scaling_block = threadIdx.x / SCALE_DIM_X;
        thread_partial_dbias +=
            partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
      }
    }
    const size_t dbias_stride = cols;
    const size_t dbias_offset_Y = blockIdx.y;
    const size_t dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x;
    const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
    const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
    if (!col_out_of_bounds_dbias) {
      dbias_workspace[dbias_idx] = thread_partial_dbias;
    }
  }
  if (amax_ptr != nullptr) {
    const int warp_id = threadIdx.x / THREADS_PER_WARP;
    // Reduce the amax over the block
    block_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(block_amax, warp_id);
  }

  if (is_master_thread && amax_ptr != nullptr) {
    atomicMaxFloat(amax_ptr, block_amax);
  }

  destroy_barriers<STAGES>(mbar, is_master_thread);
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
#endif  // __HIP_PLATFORM_AMD__
} // namespace mxfp8_kernel
constexpr size_t FP8_CHUNK_DIM_Y = 128;
constexpr size_t FP8_CHUNK_DIM_X = 128;
@@ -492,19 +572,19 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
                     const size_t cols) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
  const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y;
  const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X;

  const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK;
  const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK;

  const size_t thread_offset_Y = tid_Y;
  const size_t thread_offset_X = tid_X;

  const size_t dbias_offset_Y = blockIdx.y + tid_Y;
  const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X;
  const bool col_out_of_bounds = my_column >= cols;
  const size_t dbias_stride = cols;

  float partial_dbias = 0.f;
@@ -512,11 +592,14 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
  const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;

  // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
  __shared__ alignas(TMA_SHMEM_ALIGNMENT)
      IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
  __shared__ alignas(TMA_SHMEM_ALIGNMENT)
      IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];
  __shared__ alignas(TMA_SHMEM_ALIGNMENT)
      OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X];

  constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM;

  const bool is_master_thread = (threadIdx.x == 0);
@@ -528,13 +611,13 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
  int parity = 0;

  const size_t chunk_offset_Y = block_offset_Y;
  const size_t chunk_offset_X = block_offset_X;

#pragma unroll
  for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) {
    const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y;
    const size_t chunk_stage_offset_X = chunk_offset_X;
    if constexpr (IS_DACT) {
      copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X,
                          chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input,
@@ -549,13 +632,13 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
#pragma unroll
  for (int iter = 0; iter < FP8_ITERATIONS; ++iter) {
    const size_t buff = iter % FP8_BUFFERS_NUM;
    const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM;
    const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y;
    if (next_iter < FP8_ITERATIONS) {
      const size_t next_buff = next_iter % FP8_BUFFERS_NUM;
      const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y;
      const size_t chunk_it_offset_x = chunk_offset_X;
      if constexpr (IS_DACT) {
        copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x,
                            chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input,
@@ -572,9 +655,9 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
#pragma unroll
    for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) {
      const size_t stage_offset_Y = stage;
      const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y;
      const size_t shmem_offset_x = thread_offset_X;
      const size_t row = row_base + shmem_offset_y;
      const bool row_out_of_bounds = row >= rows;
      const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds;
@@ -613,8 +696,8 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
// Initiate TMA transfer to copy shared memory to global memory
if (is_master_thread) {
const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y;
const size_t chunk_it_offset_x = chunk_offset_X;
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), chunk_it_offset_x,
chunk_it_offset_y, reinterpret_cast<uint64_t *>(&out_sh[buff]));
@@ -632,8 +715,8 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK)
parity ^= 1;
if constexpr (IS_DBIAS) {
const size_t dbias_offset_X = my_column;
const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X;
if (!col_out_of_bounds) {
dbias_workspace[dbias_offset] = partial_dbias;
}
@@ -676,7 +759,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK;
const IType *input = input_ptr + block_offset;
OType *output = output_ptr + block_offset;
@@ -684,11 +767,11 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1;
// The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned
__shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM];
__shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM];
constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS;
constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS;
const bool is_master_thread = (threadIdx.x == 0);
@@ -704,12 +787,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
#pragma unroll
for (int iter = 0; iter < ITERATIONS; ++iter) {
const size_t buff = iter % SHMEM_BUFFERS;
const size_t it_offset = iter * SHMEM_DIM;
const size_t next_iter = iter + 1;
const size_t next_buff = next_iter % SHMEM_BUFFERS;
const size_t next_iter_offset = next_iter * SHMEM_DIM;
if (next_iter < ITERATIONS) {
copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN,
@@ -723,7 +806,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
#pragma unroll
for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) {
const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x;
float elt = static_cast<float>(in_sh[buff][shmem_offset]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
@@ -776,12 +859,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
constexpr size_t DBIAS_THREADS_PER_BLOCK = 256;
template <int nvec, typename OType>
__global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK)
reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial,
const size_t rows, const size_t cols) {
using ComputeVec = Vec<float, nvec>;
using OutputVec = Vec<OType, nvec>;
const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
if (thread_id * nvec >= cols) {
return;
@@ -812,8 +895,8 @@ __global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK)
template <typename IType>
void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols,
cudaStream_t stream) {
constexpr size_t reduce_dbias_store_bytes = 8;  // stg.64
constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);
NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape.");
const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec);
@@ -934,9 +1017,11 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
#ifdef __HIP_PLATFORM_AMD__
assert(false);
#else
using namespace mxfp8_kernel;
checkCuDriverContext(stream);

bool use_rowwise_scaling = output->has_data();
bool use_colwise_scaling = output->has_columnwise_data();
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
@@ -949,16 +1034,24 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
}
CheckNoopTensor(*noop, "cast_noop");

const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();

constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT);

constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64;
constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64;
constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 128 : 64;
constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X;
constexpr size_t BUFF_DIM_Y = THREADS_Y;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
const dim3 grid(blocks_X, blocks_Y);
const size_t block_size = THREADS_PER_CHUNK;
const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1;
const size_t scale_stride_colwise =
@@ -971,6 +1064,15 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const size_t dbias_rows = blocks_Y;
const size_t dbias_cols = cols;
ScalingType scaling_type;
if (use_rowwise_scaling && (!use_colwise_scaling)) {
scaling_type = ScalingType::ROWWISE;
} else if ((!use_rowwise_scaling) && use_colwise_scaling) {
scaling_type = ScalingType::COLWISE;
} else if (use_rowwise_scaling && use_colwise_scaling) {
scaling_type = ScalingType::BIDIMENSIONAL;
}
if constexpr (IS_DBIAS) {
NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias.");
@@ -985,58 +1087,114 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);

TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input.dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,

alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};

constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;

create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
cols, 0, input_type_bit_size);

if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols, 0, input_type_bit_size);
}

if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y,
BUFF_DIM_X, cols, 0, output_type_bit_size);
}

if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size);
}

constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems;
constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);

constexpr size_t elt_input_mem = buff_size_aligned_in;
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
constexpr size_t in_mem = elt_input_mem + act_input_mem;

const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
const size_t out_mem = out_rowwise_mem + out_colwise_mem;

const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;

cudaLaunchConfig_t cfg = {grid, block_size, dshmem_size, stream, NULL, 0};
// This kernel will only be called on sm100+, so no need to check sm_arch
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
cfg.attrs = attribute;
cfg.numAttrs = 1;

switch (scaling_type) {
case ScalingType::ROWWISE:
cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cudaLaunchKernelEx(
&cfg,
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
break;
case ScalingType::COLWISE:
cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cudaLaunchKernelEx(
&cfg,
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
break;
case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute(
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cudaLaunchKernelEx(
&cfg,
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
break;
}
if constexpr (IS_DBIAS) {
reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}); // NOLINT(*)
); // NOLINT(*)
#endif
}
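The dynamic shared-memory arithmetic above can be sanity-checked with a rough Python model. SCALE_DIM_X = 32 and BUFFS_NUM = 4 below are assumed values for illustration only (the real constants live in mxfp8_kernel); bf16 input, FP8 output, and both rowwise and colwise outputs are assumed.

def divup_to_multiple(x, m):
    # Round x up to the next multiple of m, like DIVUP_TO_MULTIPLE.
    return ((x + m - 1) // m) * m

TMA_SHMEM_ALIGNMENT = 128
BUFFS_NUM = 4                     # assumed
buff_dim_y, buff_dim_x = 32, 64   # non-dbias path: 64x64 chunk, 64 threads, SCALE_DIM_X = 32
buff_elems_total = BUFFS_NUM * buff_dim_y * buff_dim_x

in_bits, out_bits = 16, 8         # e.g. bf16 input, FP8 output
in_mem = divup_to_multiple(buff_elems_total * in_bits // 8, TMA_SHMEM_ALIGNMENT)
out_mem = divup_to_multiple(buff_elems_total * out_bits // 8, TMA_SHMEM_ALIGNMENT)

# One input buffer, rowwise plus colwise output buffers, plus alignment slack:
dshmem_size = in_mem + 2 * out_mem + TMA_SHMEM_ALIGNMENT
print(dshmem_size)  # 32896 bytes under these assumptions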
@@ -1132,8 +1290,8 @@ static bool is_full_tile_1D_tensor(const Tensor *const t) {
bool dimensions_supported_by_TMA(const Tensor *const t) {
const size_t cols = t->flat_last_dim();
constexpr size_t TMA_bytes = 16;
const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype());
return cols % alignment_requirement == 0;
}
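The alignment requirement is just 16 bytes re-expressed in elements of the tensor's dtype; a quick Python model of the same arithmetic:

def tma_alignment_elems(bits_per_elem: int) -> int:
    # 16 bytes of TMA alignment expressed in elements.
    return (16 * 8) // bits_per_elem

assert tma_alignment_elems(16) == 8   # bf16/fp16: last dim must be a multiple of 8
assert tma_alignment_elems(8) == 16   # fp8: last dim must be a multiple of 16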
@@ -1149,8 +1307,8 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons
case NVTE_DELAYED_TENSOR_SCALING: {
if (!IS_DBIAS && !IS_DACT) {
if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) &&
is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) {
// Aligned AND FP8
cast_fp8_1D<IS_ACT, ParamOP, OP>(input, output, stream);
} else {
@@ -1159,9 +1317,9 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons
}
} else if (!IS_DBIAS && IS_DACT) {
if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) &&
is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) &&
is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) {
// Aligned AND FP8 (+dAct)
cast_fp8_2D<IS_DBIAS, IS_DACT, ParamOP, OP>(input, act_input, output, dbias, workspace,
stream);
@@ -1194,7 +1352,7 @@ void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const
if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) {
// zhongboz: should we just ignore IS_ACT here?
NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) +
" or IS_DBIAS=true" + " on GPU with compute capability < 10.0.");
}
switch (output->scaling_mode) {
case NVTE_DELAYED_TENSOR_SCALING: {
...
@@ -55,6 +55,22 @@ void *get_symbol(const char *symbol, int cuda_version) {
return entry_point;
}
void ensure_context_exists() {
static thread_local bool need_check = []() {
CUcontext context;
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context);
if (context == nullptr) {
// Add primary context to context stack
CUdevice device;
NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, cuda::current_device());
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device);
NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context);
NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device);
}
return false;
}();
}
}  // namespace cuda_driver
}  // namespace transformer_engine
@@ -39,6 +39,14 @@ inline CUresult call(const char *symbol, ArgTs... args) {
return (*func)(args...);
}
/*! \brief Ensure that the calling thread has a CUDA context
*
* Each thread maintains a stack of CUDA contexts. If the calling
* thread has an empty stack, the primary context is added to the
* stack.
*/
void ensure_context_exists();
}  // namespace cuda_driver
}  // namespace transformer_engine
...
@@ -87,8 +87,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// const int thread_offset_X_colwise = tid_colwise_X;
// The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned
__shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
__shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X];
constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM;
constexpr int transaction_size = shmem_buff_size;
@@ -169,7 +169,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X;
const e8m0_t biased_exponent = scales_ptr[scale_idx];
const float block_scale = ptx::exp2f(biased_exponent);
if constexpr (USE_ROWWISE_SCALING) {
Vec<IType, ELEMS_PER_THREAD> in;
...
@@ -35,6 +35,7 @@ struct MultiPaddingArgs {
int padded_num_rows_list[kMaxTensorsPerKernel];
// Input matrix widths
int row_length_list[kMaxTensorsPerKernel];
// Prefix sum (with leading zero) of CUDA blocks needed for each
// tensor
int block_range[kMaxTensorsPerKernel + 1];
// Number of tensors being processed by kernel
...
@@ -104,6 +104,61 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr uint32_t FP32_MANTISSA_BITS = 23;
constexpr uint32_t FP32_EXPONENT_BIAS = 127;
#ifdef __HIP_PLATFORM_AMD__
#define __CUDA_ARCH_HAS_FEATURE__(FEATURE) \
((__CUDA_ARCH__ >= 100 && FEATURE == SM100_ALL) || \
(__CUDA_ARCH__ >= 101 && FEATURE == SM101_ALL) || \
(__CUDA_ARCH__ >= 120 && FEATURE == SM120_ALL))
#endif
__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) {
return (biased_exp == 0) ? 1
: __int_as_float((254 - biased_exp)
<< FP32_MANTISSA_BITS); // 127 - (biased_exp - 127)
}
__device__ __forceinline__ float exp2f(e8m0_t biased_exp) {
return __int_as_float(biased_exp << FP32_MANTISSA_BITS);
}
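Both helpers are standard exponent bit tricks: placing the biased exponent directly into the FP32 exponent field yields 2^(biased_exp - 127), and 254 - biased_exp gives the reciprocal scale (the biased_exp == 0 case of exp2f_rcp is special-cased above). A minimal Python check, assuming IEEE-754 binary32:

import struct

def e8m0_decode(biased_exp: int) -> float:
    # Same bit trick as ptx::exp2f: the biased exponent becomes the FP32
    # exponent field, i.e. the value 2**(biased_exp - 127).
    return struct.unpack("<f", struct.pack("<I", biased_exp << 23))[0]

assert e8m0_decode(127) == 1.0
assert e8m0_decode(130) == 8.0
assert e8m0_decode(254 - 130) == 0.125  # exp2f_rcp path: reciprocal of 8.0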
__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
(__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
uint16_t out;
asm volatile(
"{\n"
"cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n"
"}"
: "=h"(out)
: "f"(val));
return *reinterpret_cast<e8m0_t *>(&out);
#else
// TODO: nan/inf needs to be set for any value
// of nan/inf in input not just amax.
if (isnan(val)) {
return 0xFF;
}
if (isinf(val)) {
return 0xFE;
}
if (val == 0.0f) {
return 0x00;
}
uint32_t val_u32 = *reinterpret_cast<uint32_t *>(&val);
e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS);
uint32_t mantissa = val_u32 & 0x7FFFFF;
// Round up exponent and deal with satfinite.
if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
++exponent;
}
return exponent;
#endif
}
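A Python reference model of the fallback branch above (assuming non-negative inputs, as when converting an amax, so the sign bit never reaches the exponent field):

import math
import struct

def float_to_e8m0_ref(val: float) -> int:
    # Model of the non-SM100 path: round the power-of-two scale up, satfinite.
    if math.isnan(val):
        return 0xFF
    if math.isinf(val):
        return 0xFE
    if val == 0.0:
        return 0x00
    bits = struct.unpack("<I", struct.pack("<f", val))[0]
    exponent = bits >> 23          # assumes val >= 0 (amax), no sign bit
    mantissa = bits & 0x7FFFFF
    if (mantissa > 0 and exponent != 0xFE) and not (exponent == 0 and mantissa <= 0x400000):
        exponent += 1              # round up, saturating below 0xFF
    return exponent

assert float_to_e8m0_ref(1.0) == 127   # exactly 2**0
assert float_to_e8m0_ref(1.5) == 128   # rounds up to 2**1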
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
@@ -169,6 +224,159 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() {
asm volatile("fence.proxy.async.shared::cta;");
}
template <typename T>
struct alignas(2 * sizeof(T)) FPx2 {
T x;
T y;
};
using floatx2 = FPx2<float>;
using bf16x2 = FPx2<bf16>;
using fp16x2 = FPx2<fp16>;
using fp8e4m3x2 = FPx2<fp8e4m3>;
using fp8e5m2x2 = FPx2<fp8e5m2>;
static_assert(sizeof(floatx2) == 8);
static_assert(sizeof(bf16x2) == 4);
static_assert(sizeof(fp16x2) == 4);
static_assert(sizeof(fp8e4m3x2) == 2);
static_assert(sizeof(fp8e5m2x2) == 2);
// SIMD-like "fused" cast + multiplication (x2)
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
"mul.f32x2 val_pair, %1, %2; \n\t"
"mov.b64 {val2,val1}, val_pair; \n\t"
"cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "l"(reinterpret_cast<const uint64_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in,
const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
"mul.f32x2 val_pair, %1, %2; \n\t"
"mov.b64 {val2,val1}, val_pair; \n\t"
"cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "l"(reinterpret_cast<const uint64_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
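A scalar Python model of what these intrinsics compute per element. Rounding onto the FP8 grid is elided; only the operation order (multiply, then convert with saturation) and the satfinite bounds are modeled:

E4M3_MAX = 448.0    # max finite e4m3 value
E5M2_MAX = 57344.0  # max finite e5m2 value

def mul_cvt(x: float, scale: float, fp8_max: float) -> float:
    # Multiply by the scale, then saturate to the FP8 finite range
    # instead of overflowing to infinity (satfinite behavior).
    y = x * scale
    return max(-fp8_max, min(fp8_max, y))

assert mul_cvt(2.0, 300.0, E4M3_MAX) == 448.0  # clamps rather than producing inf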
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair_before; \n\t"
".reg.b64 val_pair_after; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
".reg.b16 val1_bf16; \n\t"
".reg.b16 val2_bf16; \n\t"
"mov.b32 {val1_bf16, val2_bf16} , %1; \n\t"
"cvt.f32.bf16 val1, val1_bf16; \n\t"
"cvt.f32.bf16 val2, val2_bf16; \n\t"
"mov.b64 val_pair_before, {val1,val2}; \n\t"
"mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
"mov.b64 {val2,val1}, val_pair_after; \n\t"
"cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair_before; \n\t"
".reg.b64 val_pair_after; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
".reg.b16 val1_bf16; \n\t"
".reg.b16 val2_bf16; \n\t"
"mov.b32 {val1_bf16, val2_bf16} , %1; \n\t"
"cvt.f32.bf16 val1, val1_bf16; \n\t"
"cvt.f32.bf16 val2, val2_bf16; \n\t"
"mov.b64 val_pair_before, {val1,val2}; \n\t"
"mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
"mov.b64 {val2,val1}, val_pair_after; \n\t"
"cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair_before; \n\t"
".reg.b64 val_pair_after; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
".reg.b16 val1_fp16; \n\t"
".reg.b16 val2_fp16; \n\t"
"mov.b32 {val1_fp16, val2_fp16} , %1; \n\t"
"cvt.f32.f16 val1, val1_fp16; \n\t"
"cvt.f32.f16 val2, val2_fp16; \n\t"
"mov.b64 val_pair_before, {val1,val2}; \n\t"
"mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
"mov.b64 {val2,val1}, val_pair_after; \n\t"
"cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, const floatx2 &scale) {
asm volatile(
"{\n"
".reg.b64 val_pair_before; \n\t"
".reg.b64 val_pair_after; \n\t"
".reg.b32 val1; \n\t"
".reg.b32 val2; \n\t"
".reg.b16 val1_fp16; \n\t"
".reg.b16 val2_fp16; \n\t"
"mov.b32 {val1_fp16, val2_fp16} , %1; \n\t"
"cvt.f32.f16 val1, val1_fp16; \n\t"
"cvt.f32.f16 val2, val2_fp16; \n\t"
"mov.b64 val_pair_before, {val1,val2}; \n\t"
"mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
"mov.b64 {val2,val1}, val_pair_after; \n\t"
"cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
"}"
: "=h"(reinterpret_cast<uint16_t &>(out))
: "r"(reinterpret_cast<const uint32_t &>(in)),
"l"(reinterpret_cast<const uint64_t &>(scale)));
}
__device__ __forceinline__ void abs_max_2x(bf16x2 &dst, const bf16x2 &p1, const bf16x2 &p2) {
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;"
: "=r"(reinterpret_cast<uint32_t &>(dst))
: "r"(reinterpret_cast<const uint32_t &>(p1)),
"r"(reinterpret_cast<const uint32_t &>(p2)));
}
__device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) {
asm volatile("max.xorsign.abs.f16x2 %0, %1, %2;"
: "=r"(reinterpret_cast<uint32_t &>(dst))
: "r"(reinterpret_cast<const uint32_t &>(p1)),
"r"(reinterpret_cast<const uint32_t &>(p2)));
}
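As I read the PTX max.xorsign.abs semantics, the result's magnitude is max(|a|, |b|) and its sign is the XOR of the two input signs; a scalar Python model under that assumption:

def abs_max_xorsign(a: float, b: float) -> float:
    # Magnitude: max of absolute values. Sign: XOR of the input signs.
    sign = -1.0 if (a < 0) != (b < 0) else 1.0
    return sign * max(abs(a), abs(b))

assert abs_max_xorsign(-3.0, 2.0) == -3.0   # signs differ -> negative result
assert abs_max_xorsign(-3.0, -2.0) == 3.0   # signs agree -> positive result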
#endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}  // namespace ptx
...
@@ -99,7 +99,9 @@
transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \
.value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \
.value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \
.value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \
.value("EXTERNAL_BULK_OVERLAP_AG", \
transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG); \
py::class_<transformer_engine::CommOverlapCore, \
std::shared_ptr<transformer_engine::CommOverlapCore>>(m, "CommOverlapCore", \
pybind11::module_local()) \
...
@@ -59,6 +59,7 @@ class Kernel {
template <typename... ArgTs>
void launch(int device_id, const dim3 grid_dim, const dim3 block_dim,
unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) {
cuda_driver::ensure_context_exists();
void *arg_ptrs[] = {const_cast<void *>(static_cast<const void *>(&args))...};
NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y,
grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes,
...
@@ -183,6 +183,7 @@ __launch_bounds__(unary_kernel_threads) __global__
VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
ComputeType max = 0;
ComputeType s = 1;
const bool requires_amax = (amax != nullptr);
if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
if (scale != nullptr) s = *scale;
}
@@ -196,10 +197,11 @@ __launch_bounds__(unary_kernel_threads) __global__
for (int i = 0; i < nvec; ++i) {
const ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
ComputeType temp = OP(val, p);
if (requires_amax) {
__builtin_assume(max >= 0);
max = fmaxf(fabsf(temp), max);
}
if constexpr (is_fp8<OutputType>::value) {
temp = temp * s;
}
if constexpr (is_int8<OutputType>::value) {
@@ -210,16 +212,17 @@ __launch_bounds__(unary_kernel_threads) __global__
}
storer.store(tid, N);
}
// Reduce amax over block
if (requires_amax) {
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
}
}
if constexpr (is_fp8<OutputType>::value) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
@@ -239,6 +242,7 @@ __launch_bounds__(unary_kernel_threads) __global__
VectorizedStorer<OutputType, nvec, aligned> storer(output, N);
ComputeType max = 0;
ComputeType s = 1;
const bool requires_amax = (amax != nullptr);
if constexpr (is_fp8<OutputType>::value || is_int8<OutputType>::value) {
if (scale != nullptr) s = *scale;
}
@@ -254,10 +258,11 @@ __launch_bounds__(unary_kernel_threads) __global__
const ComputeType val = static_cast<ComputeType>(loader.separate()[i]);
const ComputeType g = static_cast<ComputeType>(grad_loader.separate()[i]);
ComputeType temp = OP(val, p) * g;
if (requires_amax) {
__builtin_assume(max >= 0);
max = fmaxf(fabsf(temp), max);
}
if constexpr (is_fp8<OutputType>::value) {
temp = temp * s;
}
if constexpr (is_int8<OutputType>::value) {
@@ -268,16 +273,17 @@ __launch_bounds__(unary_kernel_threads) __global__
}
storer.store(tid, N);
}
// Reduce amax over block
if (requires_amax) {
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
}
}
if constexpr (is_fp8<OutputType>::value) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
@@ -412,6 +418,7 @@ __launch_bounds__(unary_kernel_threads) __global__
const size_t M = num_aligned_elements * m;
ComputeType max = 0;
ComputeType s = 1;
const bool requires_amax = (amax != nullptr);
if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale;
}
@@ -431,25 +438,28 @@ __launch_bounds__(unary_kernel_threads) __global__
const ComputeType val = static_cast<ComputeType>(loader0.separate()[i]);
const ComputeType val2 = static_cast<ComputeType>(loader1.separate()[i]);
ComputeType temp = static_cast<ComputeType>(Activation(val, p) * val2);
if (requires_amax) {
__builtin_assume(max >= 0);
max = fmaxf(fabsf(temp), max);
}
if constexpr (is_fp8<OutputType>::value) {
temp = temp * s;
}
storer.separate()[i] = static_cast<OutputType>(static_cast<ComputeType>(temp));
}
storer.store(id_x, n);
}
// Reduce amax over block
if (requires_amax) {
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
}
}
if constexpr (is_fp8<OutputType>::value) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
@@ -503,6 +513,7 @@ __launch_bounds__(unary_kernel_threads) __global__
const size_t M = num_aligned_elements * m;
ComputeType max = 0;
ComputeType s = 1;
const bool requires_amax = (amax != nullptr);
if constexpr (is_fp8<OutputType>::value) {
if (scale != nullptr) s = *scale;
}
@@ -530,11 +541,13 @@ __launch_bounds__(unary_kernel_threads) __global__
ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in;
ComputeType after_dgate = grad_val * Activation(gelu_in, p);
if (requires_amax) {
__builtin_assume(max >= 0);
max = fmaxf(fabsf(after_dgelu), max);
max = fmaxf(fabsf(after_dgate), max);
}
if constexpr (is_fp8<OutputType>::value) {
after_dgelu = after_dgelu * s;
after_dgate = after_dgate * s;
}
@@ -544,16 +557,17 @@ __launch_bounds__(unary_kernel_threads) __global__
storer0.store(id_x, n);
storer1.store(id_x, n);
}
// Reduce amax over block
if (requires_amax) {
max = reduce_max<unary_kernel_threads / THREADS_PER_WARP>(max, warp_id);
if (threadIdx.x == 0) {
static_assert(std::is_same<ComputeType, float>::value);
atomicMaxFloat(amax, max);
}
}
if constexpr (is_fp8<OutputType>::value) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) {
reciprocal<ComputeType>(scale_inv, s);
...
@@ -992,10 +992,7 @@ using fp8e5m2 = __nv_fp8_e5m2;
using e8m0_t = uint8_t;
using int8 = int8_t;

enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENSIONAL = 2 };

template <typename T>
struct Numeric_Traits;
@@ -1027,51 +1024,6 @@ struct Quantized_Limits {
static constexpr float emax_rcp = 1.0 / emax;
};
}  // namespace transformer_engine
#endif  // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_
...
@@ -5,7 +5,8 @@
"""API definition for nvidia-dlframework-inspect."""

import copy
import warnings
from typing import Dict, Union, Tuple, Optional

from nvdlfw_inspect.base import BaseNamespaceAPI, BaseConfigAPIMapper
from nvdlfw_inspect.registry import Registry
@@ -101,13 +102,23 @@ required_kwargs = {
class TEDefaultFeatures:
    """Transformer Engine API calls default behavior."""

    def fp8_gemm_enabled(
        self,
        config: Dict,
        layer_name: str,
        gemm: str,
        iteration: int,
    ) -> bool | Tuple[bool, Optional[int]]:
        """
        If the tensor is not processed using *modify_tensor* and the fp8 recipe is enabled,
        then the decision whether to cast it to fp8 is based on the value returned by the call *fp8_gemm_enabled*.
        If the tensor is processed using *modify_tensor* or fp8 autocast is not enabled,
        the result of this call does not matter.

        This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be disabled.
        It can return (bool, None) if the feature will never be enabled for that layer and gemm.
        Returning the next enabled iteration can help optimize CPU usage.

        Parameters
        ----------
@@ -122,9 +133,9 @@ class TEDefaultFeatures:
        Returns
        -------
        Union[bool, Tuple[bool, Optional[int]]] - default is (True, None)
        """
        return True, None  # if it is false, fp8_gemm will be turned off. Otherwise nothing happens.
    def modify_tensor_enabled(
        self,
@@ -133,9 +144,16 @@ class TEDefaultFeatures:
        gemm: str,
        tensor_name: str,
        iteration: int,
    ) -> bool | Tuple[bool, Optional[int]]:
        """
        It is used to determine whether *modify_tensor* will be run for a given GEMM and tensor name.
        It has **higher priority** than fp8_gemm; if *modify_tensor_enabled* returns True or (True, next_enabled_iter),
        then the *modify_tensor* call is invoked for the respective tensor no matter what.

        This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled.
        It can return (bool, None) if the feature will never be enabled for that layer, gemm and tensor.
        Returning the next enabled iteration can help optimize CPU usage, especially when the interval between modify_tensor is large.
        Returning only a bool is deprecated.

        Parameters
        ----------
@@ -153,9 +171,9 @@ class TEDefaultFeatures:
        Returns
        -------
        Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)
        """
        return False, None
    def modify_tensor(
        self,
@@ -167,7 +185,7 @@ class TEDefaultFeatures:
        default_quantizer: Quantizer,
        iteration: int,
        out: Union[torch.Tensor, QuantizedTensor],
    ) -> torch.Tensor | QuantizedTensor | None:
        """
        It allows tensor modification.
        For example, feature `FakeQuant` uses it to emulate casting to FP8.
@@ -227,6 +245,9 @@ class TEDefaultFeatures:
        layer_name: str,
        tensor_name: str,
        tensor: torch.Tensor,
        rowwise_quantized_tensor: Optional[torch.Tensor],
        columnwise_quantized_tensor: Optional[torch.Tensor],
        quantizer: Optional[Quantizer],
        iteration: int,
        tp_group: torch.distributed.ProcessGroup,
    ) -> None:
@@ -243,6 +264,12 @@ class TEDefaultFeatures:
            one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
        tensor: torch.Tensor
            tensor in high precision,
        rowwise_quantized_tensor: Optional[torch.Tensor]
            rowwise quantized tensor,
        columnwise_quantized_tensor: Optional[torch.Tensor]
            columnwise quantized tensor,
        quantizer: Optional[Quantizer]
            quantizer,
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.
        tp_group: torch.distributed.ProcessGroup
@@ -260,12 +287,15 @@ class TEDefaultFeatures:
        config: Dict,
        layer_name: str,
        tensor_name: str,
        tensor: torch.Tensor,
        iteration: int,
        tp_group: torch.distributed.ProcessGroup,
        rowwise: bool,
    ) -> None:
        """
        This call is deprecated; we advise using *inspect_tensor* instead.

        Similar to *inspect_tensor*, but run after the fp8 cast or *modify_tensor*, if either is invoked.
        If neither the fp8 cast nor modify_tensor is invoked, then *inspect_tensor_postquantize* is not invoked either.
        The feature LogFp8Stats uses this call to collect FP8 statistics after the quantization.

        Parameters
@@ -278,8 +308,6 @@ class TEDefaultFeatures:
            one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`],
        tensor: torch.Tensor
            tensor in fp8 or processed tensor after the modify_tensor call,
        iteration: int
            iteration number - equal to the number of times `debug_api.step()` was called.
        tp_group: torch.distributed.ProcessGroup
@@ -298,9 +326,15 @@ class TEDefaultFeatures:
        layer_name: str,
        tensor_name: str,
        iteration: int,
    ) -> bool | Tuple[bool, Optional[int]]:
        """
        It is a routing call, which is run at the initialization of the layer.
        Determines if *inspect_tensor* for a given GEMM and tensor will be invoked.

        This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled.
        It can return (bool, None) if the feature will never be enabled for that layer and tensor.
        Returning the next enabled iteration can help optimize CPU usage, especially when the interval between inspect_tensor is large.
        Returning only a bool is deprecated.

        Parameters
        ----------
@@ -316,9 +350,9 @@ class TEDefaultFeatures:
        Returns
        -------
        Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)
        """
        return False, None
def inspect_tensor_postquantize_enabled( def inspect_tensor_postquantize_enabled(
self, self,
...@@ -327,11 +361,18 @@ class TEDefaultFeatures: ...@@ -327,11 +361,18 @@ class TEDefaultFeatures:
gemm: str, gemm: str,
tensor_name: str, tensor_name: str,
iteration: int, iteration: int,
) -> bool: ) -> bool | Tuple[bool, Optional[int]]:
""" """
This is deprecated call, we advise to use *inspect_tensor* and *inspect_tensor_enabled* instead.
It is a routing call, which is run at the initialization of the layer. It is a routing call, which is run at the initialization of the layer.
If it returns true, then *inspect_tensor_postquantize* for Determines if *inspect_tensor_postquantize* for a given GEMM and tensor will be invoked.
a given GEMM and tensor will be invoked.
This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled.
It can return (bool, None) if the feature will never be enabled for that layer, gemm and tensor name.
Returning the next enabled iteration can help optimize CPU usage,
especially when the interval between inspect_tensor_postquantize is large.
Returning only a bool is deprecated.
Parameters
----------
@@ -349,9 +390,9 @@ class TEDefaultFeatures:
Returns
-------
Union[bool, Tuple[bool, Optional[int]]] - default is (False, None)
"""
return False, None
@Registry.register_namespace_api(namespace="transformer_engine")
@@ -371,8 +412,8 @@ class TransformerEngineAPI(BaseNamespaceAPI):
"modify_tensor": ["tensor_name", "gemm"],
"inspect_tensor": ["tensor_name"],
"inspect_tensor_postquantize": ["tensor_name"],
"inspect_tensor_enabled": ["tensor_name", "iteration"],
"inspect_tensor_postquantize_enabled": ["tensor_name", "iteration"],
"modify_tensor_enabled": ["tensor_name"],
}
@@ -420,7 +461,7 @@ class TransformerEngineAPI(BaseNamespaceAPI):
def output_assertions_hook(self, api_name, ret, **kwargs):
"""Output hooks used to check correctness of the outputs of the API calls."""
if "enabled" in api_name or api_name == "fp8_gemm":
assert isinstance(ret, (bool, tuple))
if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]:
assert ret is None
if api_name == "modify_tensor":
@@ -432,6 +473,57 @@ class TransformerEngineAPI(BaseNamespaceAPI):
if kwargs["dtype"] is not None:
assert ret.dtype == kwargs["dtype"]
def call_feature(self, call, feat_config, layer_name, **kwargs):
"""
For backward compatibility, remove kwargs that are not needed for the call
"""
if call.__name__ == "inspect_tensor":
kwargs_copy = kwargs.copy()
for k in ["quantizer", "columnwise_quantized_tensor", "rowwise_quantized_tensor"]:
if k not in call.__code__.co_varnames:
kwargs_copy.pop(k)
else:
kwargs_copy = kwargs
if call.__name__ == "inspect_tensor_postquantize":
warnings.warn(
"inspect_tensor_postquantize is deprecated, use inspect_tensor instead.",
DeprecationWarning,
)
return call(feat_config, layer_name, **kwargs_copy)
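For instance, a legacy feature written against the old *inspect_tensor* signature (a hypothetical example) keeps working, because the newer kwargs are stripped before the call:

class LegacyFeature:  # hypothetical pre-existing user feature
    def inspect_tensor(self, config, layer_name, tensor_name, iteration, tp_group, tensor):
        # Never receives quantizer / rowwise_quantized_tensor /
        # columnwise_quantized_tensor; call_feature removed them above.
        print(layer_name, tensor_name, tensor.shape)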
def handle_multi_feature_output(
self, api_name, multi_feature_outputs, features_to_invoke, **kwargs
):
"""
Handle multi-tensor output of the API calls.
"""
if "enabled" in api_name:
# *_enabled feature calls can return bool, or tuple (bool, Optional[int]).
# If any of them returns bool, then we return bool - this means that we cannot state anything
# about enablement in the next steps.
# If all of them return a tuple (bool, Optional[int]), we return the minimum value,
# representing the number of steps after the feature will be enabled next time.
# If the second value is None, that means that the feature will never be enabled.
all_ret_tuple = all(
isinstance(feature_output, tuple) for feature_output in multi_feature_outputs
)
if all_ret_tuple:
run_current = any(feature_output[0] for feature_output in multi_feature_outputs)
next_iter = None
for feature_output in multi_feature_outputs:
if next_iter is None:
next_iter = feature_output[1]
elif feature_output[1] is not None:
next_iter = min(next_iter, feature_output[1])
return run_current, next_iter
run_current = any(feature_output for feature_output in multi_feature_outputs)
return run_current, None
return super().handle_multi_feature_output(
api_name, multi_feature_outputs, features_to_invoke, **kwargs
)
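A worked example of the combination rule (values hypothetical):

# Two features attached to the same tensor: run now if any wants to run,
# and re-query at the earliest iteration any feature reports.
outputs = [(False, 20), (True, 12)]
run_current = any(o[0] for o in outputs)                    # True
next_iter = min(o[1] for o in outputs if o[1] is not None)  # 12
# If one feature had returned a bare bool, the result would fall back to
# (run_current, None): nothing can be said about future steps.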
def step(self):
"""This function is called by the nvidia-dlframework-inspect after every debug_api.step()"""
STATS_BUFFERS.log_stats()
...
@@ -50,4 +50,4 @@ class DisableFP8GEMM(TEConfigAPIMapper):
# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behaviour in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1
@@ -41,7 +41,7 @@ class DisableFP8Layer:
# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1
def parse_config_and_api(self, config, **_kwargs):
"""Determines whether to run the API
...
@@ -127,14 +127,14 @@ class FakeQuant(TEConfigAPIMapper):
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and FP8 GEMM execution."""
return False, None
@api_method
def modify_tensor_enabled(
self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run process_tensor() in the forward."""
return True, iteration + 1
@api_method
def modify_tensor(
...
@@ -4,54 +4,119 @@
"""LogFp8TensorStats Feature support for nvidia-dlframework-inspect"""
from typing import Dict, Optional, List, Tuple
from contextlib import contextmanager
import torch
import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats
from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor
from transformer_engine.pytorch.tensor.float8_tensor import (
Float8Quantizer,
Float8CurrentScalingQuantizer,
)
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer
from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter
ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"]
def _get_recipe_name(quantizer: Optional[Quantizer]):
if quantizer is None:
return ""
if isinstance(quantizer, Float8Quantizer):
return "fp8_delayed_scaling"
if isinstance(quantizer, Float8CurrentScalingQuantizer):
return "fp8_current_scaling"
if isinstance(quantizer, MXFP8Quantizer):
return "mxfp8"
if isinstance(quantizer, Float8BlockQuantizer):
return "fp8_block_scaling"
raise ValueError(f"Unsupported quantizer type: {type(quantizer)}")
def _get_new_quantizer(recipe_name, fp8_dtype):
if recipe_name == "fp8_block_scaling":
return Float8BlockQuantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True)
if recipe_name == "fp8_current_scaling":
return Float8CurrentScalingQuantizer(
fp8_dtype=fp8_dtype, device=torch.device("cuda"), rowwise=True, columnwise=True
)
if recipe_name == "mxfp8":
return MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True)
if recipe_name == "fp8_delayed_scaling":
raise ValueError("Cannot recreate quantizer for fp8_delayed_scaling")
raise ValueError(f"Unsupported recipe name: {recipe_name}")
@Registry.register_feature(namespace="transformer_engine")
class LogFp8TensorStats(BaseLogTensorStats):
"""
Logs statistics of quantized tensors.
Supports computing statistics for the current recipe, but also
lets you see what would happen if different recipes were used for these tensors in the current iteration.
For example, during delayed-scaling training you may wish to track
"fp8_current_scaling_underflows%" to measure the accuracy of the current scaling
factors; note that this requires an extra cast and therefore adds overhead.
Using a logging frequency (`freq`) greater than 1 is recommended in this case.
Computing the stats matching the training recipe does not require an extra cast.
Statistics are identified by the pattern `<recipe>_<stat>` with an optional `_columnwise` suffix (e.g.
`fp8_delayed_scaling_underflows%` or `mxfp8_scale_inv_min_columnwise`).
One can provide `<stat>` only; then the current training recipe is used.
Stats for delayed scaling cannot be collected if delayed scaling is not the current training recipe.
In distributed runs each rank first computes its local statistics; the values
are gathered the next time `debug_api.step()` is called. Remember to call
`debug_api.step()` every training step so the logs are flushed.
The feature is micro-batch aware: if several forward/backward passes occur
between successive `debug_api.step()` calls, statistics are accumulated for all
tensors except weights.
Collecting FP8 statistics is expensive. Choosing a larger `freq` reduces the
overhead, and if the feature is skipped for a step the additional cost is
minimal. When no other debug feature is active, the layer runs at normal
Transformer Engine speed.
Parameters
----------
stats: List[str]
Each stat is a string of the form `<recipe>_<stat>`, with an optional `_columnwise` suffix (i.e., `<recipe>_<stat>_columnwise`).
If `<recipe>` is omitted, the current training recipe is used.
For mxfp8 and fp8_block_scaling, the `_columnwise` suffix can be provided; the stat is then computed on the columnwise (transpose)
version of the tensor, which can be numerically different from the rowwise (non-transpose) version.
The `_columnwise` suffix is not supported for fp8_delayed_scaling and fp8_current_scaling.
recipes:
- fp8_delayed_scaling,
- fp8_current_scaling,
- mxfp8,
- fp8_block_scaling,
stats:
- underflows% - percentage of non-zero elements of the tensor clipped to 0 after quantization,
- overflows% - percentage of elements of the tensor clipped to the max/min value of the FP8 range - supported only for fp8_delayed_scaling,
- scale_inv_min - minimum of the inverse of the scaling factors,
- scale_inv_max - maximum of the inverse of the scaling factors,
- mse - mean squared error of the quantized tensor and the original tensor = sum((quantized_tensor - original_tensor)**2) / num_elements,
tensors/tensors_struct: List[str]
list of tensors to log
- activation
- gradient
- weight
freq: Optional[int], default = 1
frequency of logging stats, stats will be logged every `freq` steps
start_step: Optional[int], default = None
@@ -74,7 +139,7 @@ class LogFp8TensorStats(BaseLogTensorStats):
enabled: True
tensors_struct:
- tensor: activation
stats: [mxfp8_underflows%]
freq: 1
- tensor: gradient
stats: [underflows%]
@@ -83,42 +148,147 @@ class LogFp8TensorStats(BaseLogTensorStats):
end_step: 80
"""
def check_if_stat_is_supported(self, stat: str, current_recipe: str):
"""Returns True if stat is supported, raises ValueError otherwise."""
columnwise = stat.endswith("_columnwise")
if columnwise:
stat = stat[: -len("_columnwise")]
recipe_from_stat, _ = self.get_recipe_from_stat(stat, default_recipe=current_recipe)
stat_without_recipe = stat.replace(recipe_from_stat + "_", "")
if current_recipe == "" and recipe_from_stat == "":
raise ValueError(
f"Stat {stat} does not contain a recipe name and the current recipe is not set."
)
if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES:
raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}")
if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise:
raise ValueError(
f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for"
" fp8_delayed_scaling and fp8_current_scaling."
)
if recipe_from_stat == "fp8_delayed_scaling" and stat_without_recipe == "overflows%":
return True
if recipe_from_stat in ["fp8_block_scaling"] and torch.cuda.get_device_capability()[0] < 9:
raise ValueError(f"Stat {stat} needs Hopper or later GPU.")
if recipe_from_stat == "mxfp8" and torch.cuda.get_device_capability()[0] < 10:
raise ValueError(f"Stat {stat} needs Blackwell or later GPU.")
supported_stats = ["underflows%", "scale_inv_min", "scale_inv_max", "mse"]
if stat_without_recipe not in supported_stats:
raise ValueError(
f"Stat {stat} contains an unsupported stat name: {stat_without_recipe}"
)
return True
def get_recipe_from_stat(self, stat: str, default_recipe: str = ""):
"""Returns the recipe name from the stat string."""
columnwise_stat = stat.endswith("_columnwise")
for recipe_name in ALL_RECIPE_NAMES:
if recipe_name in stat:
return recipe_name, columnwise_stat
return default_recipe, columnwise_stat
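For example, given a LogFp8TensorStats instance `feat` (name hypothetical):

# The recipe prefix and the `_columnwise` suffix are both detected,
feat.get_recipe_from_stat("mxfp8_scale_inv_min_columnwise")       # ("mxfp8", True)
# and a bare stat falls back to the supplied default recipe.
feat.get_recipe_from_stat("underflows%", default_recipe="mxfp8")  # ("mxfp8", False)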
@contextmanager
def update_aux_dict(
self,
aux_dict: Dict,
recipe_name: str,
quantized_tensor: QuantizedTensor,
quantizer: Quantizer,
original_tensor: torch.Tensor,
recipes_in_stats: List[Tuple[str, bool]],
):
"""
Updates the aux_dict with the quantized tensor for each recipe provided in recipes_in_stats.
This allows stats for different recipes to be computed in the same iteration,
without requantizing the tensor separately for every stat.
Also updates the rowwise/columnwise usage of the quantized tensor as needed.
Yields the aux_dict.
Cleans up on exit, because it may have changed the usage of the quantized tensor.
"""
fp8_dtype = None
if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]:
assert isinstance(
quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer)
)
fp8_dtype = quantizer.dtype
aux_dict = {
recipe_name: quantized_tensor,
}
old_rowwise_usage = quantizer.rowwise_usage
old_columnwise_usage = quantizer.columnwise_usage
for cur_recipe_name, cur_columnwise_stat in recipes_in_stats:
if recipe_name != cur_recipe_name:
quantizer = _get_new_quantizer(cur_recipe_name, fp8_dtype)
aux_dict[cur_recipe_name] = quantizer(original_tensor)
elif isinstance(quantized_tensor, QuantizedTensor):
if cur_columnwise_stat:
quantized_tensor.update_usage(columnwise_usage=True)
else:
quantized_tensor.update_usage(rowwise_usage=True)
aux_dict[""] = quantized_tensor
aux_dict[cur_recipe_name] = quantized_tensor
try:
yield aux_dict
finally:
if isinstance(quantized_tensor, QuantizedTensor):
quantized_tensor.update_usage(
rowwise_usage=old_rowwise_usage, columnwise_usage=old_columnwise_usage
)
@api_method
def inspect_tensor_enabled(
self, config: Dict, layer_name: str, tensor_name: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run inspect_tensor() in the forward."""
run_current, next_iter = next_enabled_iter(
config.get("start_step", None),
config.get("end_step", None),
config.get("start_end_list", None),
config.get("freq", 1),
iteration,
)
STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter
return run_current, next_iter
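A simplified sketch of the semantics expected from `next_enabled_iter` (the real helper lives in transformer_engine.debug.features.utils; `start_end_list` handling is omitted and `end_step` is assumed exclusive):

def next_enabled_iter_sketch(start_step, end_step, freq, iteration):
    start = start_step or 0

    def enabled(i):
        in_window = i >= start and (end_step is None or i < end_step)
        return in_window and (i - start) % freq == 0

    run_current = enabled(iteration)
    candidate = iteration + 1
    while end_step is None or candidate < end_step:
        if enabled(candidate):
            return run_current, candidate
        candidate += 1
    return run_current, None  # never enabled again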
@api_method
def inspect_tensor(
self,
config: Dict,
layer_name: str,
tensor_name: str,
iteration: int,
tp_group: torch.distributed.ProcessGroup,
tensor: torch.Tensor,
rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
quantizer: Optional[Quantizer] = None,
):
"""
API call used to collect the data about the tensor after process_tensor()/quantization.
"""
assert rowwise_quantized_tensor is columnwise_quantized_tensor
assert (
quantizer is not None
), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe."
quantized_tensor = rowwise_quantized_tensor
assert isinstance(
quantized_tensor, QuantizedTensor
), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor."
recipe_name = _get_recipe_name(quantizer)
for stat in config["stats"]:
self.check_if_stat_is_supported(stat, recipe_name)
options = (
config.get("start_step", None),
@@ -127,19 +297,9 @@ class LogFp8TensorStats(BaseLogTensorStats):
"fp8",
)
skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
tensor_name, tp_group
)
STATS_BUFFERS.try_add_buffer(
layer_name=layer_name,
@@ -150,10 +310,30 @@ class LogFp8TensorStats(BaseLogTensorStats):
reduce_within_microbatch=reduce_within_microbatch,
)
recipes_in_stats = [
self.get_recipe_from_stat(stat, default_recipe=recipe_name) for stat in config["stats"]
]
with self.update_aux_dict(
aux_dict={},
recipe_name=recipe_name,
quantized_tensor=quantized_tensor,
quantizer=quantizer,
original_tensor=tensor,
recipes_in_stats=recipes_in_stats,
) as aux_dict:
STATS_BUFFERS.feed(
layer_name,
tensor_name,
options,
tensor,
iteration,
skip_reduction,
aux_dict=aux_dict,
)
debug_api.log_message(
f"Feature={self.__class__.__name__}, API=inspect_tensor: {tensor_name}",
layer_name,
extra_cachable_args=(tensor_name,),
)
@@ -4,7 +4,7 @@
"""LogTensorStats Feature support for nvidia-dlframework-inspect"""
from typing import Dict, Optional
import torch
@@ -12,13 +12,13 @@ from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as Bas
from nvdlfw_inspect.registry import Registry, api_method
import nvdlfw_inspect.api as debug_api
from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer
from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase
from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS
from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params
@Registry.register_feature(namespace="transformer_engine")
@@ -97,7 +97,15 @@ class LogTensorStats(BaseLogTensorStats):
self, config: Dict, layer_name: str, tensor_name: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run look_at_tensor_before_process() in the forward."""
run_current, next_iter = next_enabled_iter(
config.get("start_step", None),
config.get("end_step", None),
config.get("start_end_list", None),
config.get("freq", 1),
iteration,
)
STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter
return run_current, next_iter
@api_method
def inspect_tensor(
@@ -105,10 +113,13 @@ class LogTensorStats(BaseLogTensorStats):
config: Dict,
layer_name: str,
tensor_name: str,
iteration: int,
tp_group: torch.distributed.ProcessGroup,
tensor: torch.Tensor,
rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None,
quantizer: Optional[Quantizer] = None,
): # pylint: disable=unused-argument
"""API call used to collect the data about the tensor before process_tensor()/quantization.""" """API call used to collect the data about the tensor before process_tensor()/quantization."""
assert ( assert (
@@ -125,14 +136,9 @@ class LogTensorStats(BaseLogTensorStats):
config.get("start_end_list", None),
)
skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params(
tensor_name, tp_group
)
for stat in config["stats"]:
assert (
...
@@ -91,14 +91,14 @@ class PerTensorScaling(TEConfigAPIMapper):
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and FP8 GEMM execution."""
return False, None
@api_method
def modify_tensor_enabled(
self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call used to determine whether to run process_tensor() in the forward."""
return True, iteration + 1
@api_method
def modify_tensor(
...