".circleci/unittest/vscode:/vscode.git/clone" did not exist on "8a03087ede7d5b58e6562e4d2ac78dc904303b56"
Commit 9df0c4a3 authored by wenjh's avatar wenjh
Browse files

Merge branch 'nv_main'

parents 0d874a4e f122b07d
......@@ -18,6 +18,7 @@
#include "../../util/vectorized_pointwise.h"
#include "../core/common.cuh"
#include "../fp8/quantize_fp8.cuh"
#include "../mxfp8/group_quantize_mxfp8.cuh"
#include "../mxfp8/quantize_mxfp8.cuh"
#include "../nvfp4/group_quantize_transpose_nvfp4.cuh"
#include "../nvfp4/quantize_nvfp4.cuh"
......@@ -381,6 +382,89 @@ void group_quantize_fwd_helper(const NVTETensor input, NVTETensor *outputs,
}
}
template <bool IS_ACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor output,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;
NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output);
const NVTEGroupedTensor activation = nullptr;
NVTETensor dbias = nullptr;
NVTETensor workspace = nullptr;
const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input);
GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation);
Tensor *dbias_tensor = convertNVTETensor(dbias);
Tensor *workspace_tensor = convertNVTETensor(workspace);
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Dispatch to quantization kernel depending on data format
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING: {
mxfp8::group_quantize</*IS_DBIAS=*/false, /*IS_DACT=*/false, IS_ACT, ParamOP, OP>(
input_tensor, activations_tensor, noop_tensor, output_tensor, dbias_tensor,
workspace_tensor, stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
}
}
template <bool IS_DBIAS, bool IS_DACT, typename ParamOP, float (*OP)(float, const ParamOP &)>
void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace,
const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
using namespace detail;
NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output);
const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad);
const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input);
GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
Tensor *dbias_tensor = convertNVTETensor(dbias);
Tensor *workspace_tensor = convertNVTETensor(workspace);
// Quantization config
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
// Noop flag
Tensor dummy_tensor;
Tensor *noop_tensor = &dummy_tensor;
if (quant_config_cpp.noop_tensor != nullptr) {
noop_tensor = convertNVTETensorCheck(quant_config_cpp.noop_tensor);
}
// Dispatch to quantization kernel depending on data format
switch (scaling_mode) {
case NVTE_MXFP8_1D_SCALING: {
mxfp8::group_quantize<IS_DBIAS, IS_DACT, /*IS_ACT=*/false, ParamOP, OP>(
grad_tensor, input_tensor, noop_tensor, output_tensor, dbias_tensor, workspace_tensor,
stream);
break;
}
default:
NVTE_ERROR("Not implemented scaling mode: " + to_string(scaling_mode) + ".");
}
}
} // namespace dispatch
} // namespace transformer_engine
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file group_quantize_mxfp8.cuh
* \brief CUDA kernels to quantize grouped tensors to MXFP8.
*/
#ifndef TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_
#define TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../common.h"
#include "../../util/math.h"
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "../core/common.cuh"
#include "swizzle.cuh"
namespace transformer_engine {
namespace dispatch {
namespace mxfp8 {
namespace group_quantize_kernel {
constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64;
__device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS];
__device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS];
__device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS];
__device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS];
enum ShapeRepresentation {
SAME_BOTH_DIMS = 0,
VARYING_FIRST_DIM = 1,
VARYING_LAST_DIM = 2,
VARYING_BOTH_DIMS = 3
};
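// Examples (per the shape checks in group_quantize below): shapes {[128, 256], [128, 256]} -> SAME_BOTH_DIMS;
// {[64, 256], [192, 256]} -> VARYING_FIRST_DIM; {[128, 256], [128, 512]} -> VARYING_LAST_DIM;
// {[64, 256], [192, 512]} -> VARYING_BOTH_DIMS.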
constexpr size_t SCALE_DIM_Y = 32;
constexpr size_t SCALE_DIM_X = 32;
constexpr size_t BUFFS_NUM = 2;
constexpr size_t PACK_SIZE = 4;
constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE;
constexpr size_t CHUNK_DIM_Y = 128;
constexpr size_t CHUNK_DIM_X = 128;
constexpr size_t THREADS_PER_CHUNK = 128;
constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X;
constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X;
constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X;
constexpr size_t BUFF_DIM_Y = THREADS_Y;
constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X;
static_assert(BUFF_DIM_Y == 32);
constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y;
static_assert(STAGES >= 1);
// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory
constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32
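// With the values above: THREADS_X = 128 / 32 = 4, THREADS_Y = 128 / 4 = 32,
// BUFF_DIM = 32 * 128 = 4096 elements per buffer, and STAGES = 128 / 32 = 4,
// i.e. each 128x128 chunk is processed in 4 double-buffered (BUFFS_NUM = 2) stages of 32 rows.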
__device__ __forceinline__ size_t get_current_tensor_id(
const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset,
const size_t first_logical_dim, const size_t last_logical_dim,
const int64_t *const __restrict__ offsets_ptr) {
if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) {
const size_t current_row = current_offset / last_logical_dim;
const size_t rows_per_tensor = first_logical_dim / num_tensors;
return current_row / rows_per_tensor;
} else {
size_t low = 1;
size_t hi = num_tensors; // [low, hi]
while (low < hi) {
const size_t mid = low + (hi - low) / 2;
const size_t mid_offset = static_cast<size_t>(offsets_ptr[mid]);
if (mid_offset <= current_offset) {
low = mid + 1;
} else {
hi = mid;
}
}
return low - 1;
}
}
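// Example (illustrative): with offsets_ptr = {0, 128, 384} and num_tensors = 3, a
// current_offset of 200 falls in the range [128, 384), so the binary search returns tensor_id = 1.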
__device__ __forceinline__ size_t get_tensor_rows_num(
const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t first_logical_dim,
const int64_t *const __restrict__ first_dims_ptr, const size_t num_tensors) {
size_t rows_num = 0;
switch (shape_rep) {
case ShapeRepresentation::SAME_BOTH_DIMS:
case ShapeRepresentation::VARYING_LAST_DIM:
rows_num = first_logical_dim;
break;
case ShapeRepresentation::VARYING_FIRST_DIM:
case ShapeRepresentation::VARYING_BOTH_DIMS:
rows_num = static_cast<size_t>(first_dims_ptr[tensor_id]);
break;
}
return rows_num;
}
__device__ __forceinline__ size_t get_tensor_cols_num(
const size_t tensor_id, const ShapeRepresentation shape_rep, const size_t last_logical_dim,
const int64_t *const __restrict__ last_dims_ptr) {
size_t cols_num = 0;
switch (shape_rep) {
case ShapeRepresentation::SAME_BOTH_DIMS:
case ShapeRepresentation::VARYING_FIRST_DIM:
cols_num = last_logical_dim;
break;
case ShapeRepresentation::VARYING_LAST_DIM:
case ShapeRepresentation::VARYING_BOTH_DIMS:
cols_num = static_cast<size_t>(last_dims_ptr[tensor_id]);
break;
}
return cols_num;
}
// Copies the base tensor map to shmem, modifies the copy, and stores the modified tensor map at the destination slot in global memory
__device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_tensor_map,
CUtensorMap *global_tensor_map,
const uintptr_t global_data_ptr,
const size_t global_dim_Y,
const size_t global_dim_X,
const size_t data_type_size_bytes) {
__shared__ CUtensorMap shared_tensor_map;
shared_tensor_map = base_tensor_map; // Copy the base tensor map into shmem
constexpr bool is_blackwell = ARCH_BLACKWELL_FAMILY;
if constexpr (is_blackwell) {
const size_t global_stride_bytes = global_dim_X * data_type_size_bytes;
if (global_stride_bytes % TMA_GMEM_ALIGNMENT != 0) {
NVTE_DEVICE_ERROR("Shape not supported, as data stride must be 16B aligned.");
}
if (global_data_ptr % TMA_GMEM_ALIGNMENT != 0) {
NVTE_DEVICE_ERROR("Tensor data pointer must be 16B aligned");
}
asm volatile(
"{\n\t"
".reg.b64 tensor_map_ptr; \n\t"
"mov.b64 tensor_map_ptr, %0; \n\t"
"tensormap.replace.tile.global_address.b1024.b64 [tensor_map_ptr], %1; \n\t"
"tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 1, %2; \n\t" // DIM Y
"tensormap.replace.tile.global_dim.b1024.b32 [tensor_map_ptr], 0, %3; \n\t" // DIM X
"tensormap.replace.tile.global_stride.b1024.b64 [tensor_map_ptr], 0, %4; \n"
"}\n" ::"l"(reinterpret_cast<uintptr_t>(&shared_tensor_map)),
"l"(global_data_ptr), "r"(static_cast<uint32_t>(global_dim_Y)),
"r"(static_cast<uint32_t>(global_dim_X)), "l"(static_cast<uint64_t>(global_stride_bytes))
: "memory");
*global_tensor_map = shared_tensor_map;
} else {
NVTE_DEVICE_ERROR(
"tensormap.replace is architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
}
template <typename IType, typename OType>
__global__ void update_tma_descriptors(
const __grid_constant__ CUtensorMap base_tensor_map_input,
const __grid_constant__ CUtensorMap base_tensor_map_act_input,
const __grid_constant__ CUtensorMap base_tensor_map_output_rowwise,
const __grid_constant__ CUtensorMap base_tensor_map_output_colwise,
const IType *const __restrict__ input_data_ptr,
const IType *const __restrict__ act_input_data_ptr,
const OType *const __restrict__ output_rowwise_data_ptr,
const OType *const __restrict__ output_colwise_data_ptr, const ShapeRepresentation shape_rep,
const size_t num_tensors, const size_t first_logical_dim, const size_t last_logical_dim,
const int64_t *const __restrict__ offsets_ptr, const int64_t *const __restrict__ first_dims_ptr,
const int64_t *const __restrict__ last_dims_ptr, const bool rowwise, const bool colwise,
const bool compute_dactivations) {
const bool leading_thread = (threadIdx.x == 0);
const size_t tensor_id = blockIdx.x;
const size_t rows =
get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr);
const size_t offset_elts = offsets_ptr[tensor_id];
if (leading_thread && (tensor_id < num_tensors)) {
{
const uintptr_t global_data_ptr = reinterpret_cast<uintptr_t>(input_data_ptr + offset_elts);
modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id],
global_data_ptr, rows, cols, sizeof(IType));
}
if (compute_dactivations) {
const uintptr_t global_data_ptr =
reinterpret_cast<uintptr_t>(act_input_data_ptr + offset_elts);
modify_base_tensor_map(base_tensor_map_act_input, &g_tensor_maps_act_input[tensor_id],
global_data_ptr, rows, cols, sizeof(IType));
}
if (rowwise) {
const uintptr_t global_data_ptr =
reinterpret_cast<uintptr_t>(output_rowwise_data_ptr + offset_elts);
modify_base_tensor_map(base_tensor_map_output_rowwise,
&g_tensor_maps_output_rowwise[tensor_id], global_data_ptr, rows, cols,
sizeof(OType));
}
if (colwise) {
const uintptr_t global_data_ptr =
reinterpret_cast<uintptr_t>(output_colwise_data_ptr + offset_elts);
modify_base_tensor_map(base_tensor_map_output_colwise,
&g_tensor_maps_output_colwise[tensor_id], global_data_ptr, rows, cols,
sizeof(OType));
}
}
}
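// Note: this kernel is launched with one block per tensor (grid = num_tensors, 32 threads per
// block, see the host code below); only thread 0 of each block patches the descriptors.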
__device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tensor_map) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("fence.proxy.tensormap::generic.acquire.cta [%0], 128;" ::"l"(tensor_map));
#else
NVTE_DEVICE_ERROR("fence_acquire_tensormap is only supported on SM 9.0+.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
}
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING,
bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES>
__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel(
const __grid_constant__ CUtensorMap tensor_map_input_static,
const __grid_constant__ CUtensorMap tensor_map_act_input_static,
const __grid_constant__ CUtensorMap tensor_map_output_rowwise_static,
const __grid_constant__ CUtensorMap tensor_map_output_colwise_static,
const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t first_logical_dim,
const size_t last_logical_dim, const int64_t *const __restrict__ offsets_ptr,
const int64_t *const __restrict__ first_dims_ptr,
const int64_t *const __restrict__ last_dims_ptr, e8m0_t *const __restrict__ scales_rowwise_ptr,
e8m0_t *const __restrict__ scales_colwise_ptr, const float *__restrict__ noop,
float *const __restrict__ dbias_workspace, float *const __restrict__ amax_ptr) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT;
constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS;
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;
using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx;
if constexpr (NO_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
}
constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING;
const size_t block_global_offset = blockIdx.x * ELTS_PER_CHUNK;
const size_t tensor_id = get_current_tensor_id(shape_rep, num_tensors, block_global_offset,
first_logical_dim, last_logical_dim, offsets_ptr);
const size_t rows =
get_tensor_rows_num(tensor_id, shape_rep, first_logical_dim, first_dims_ptr, num_tensors);
const size_t cols = get_tensor_cols_num(tensor_id, shape_rep, last_logical_dim, last_dims_ptr);
const size_t scale_stride_rowwise = DIVUP_TO_MULTIPLE(DIVUP(cols, static_cast<size_t>(32)), 4);
const size_t scale_stride_colwise = DIVUP_TO_MULTIPLE(cols, 128);
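// E.g. for cols = 4096 (assuming DIVUP rounds up and DIVUP_TO_MULTIPLE rounds up to a multiple):
// scale_stride_rowwise = DIVUP_TO_MULTIPLE(4096 / 32, 4) = 128, scale_stride_colwise = 4096.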
const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM);
// A grouped tensor can be treated as a contiguous tensor for MXFP8
const size_t tensor_base = is_single_tensor ? 0 : static_cast<size_t>(offsets_ptr[tensor_id]);
const CUtensorMap &tensor_map_input =
is_single_tensor ? tensor_map_input_static : g_tensor_maps_input[tensor_id];
const CUtensorMap &tensor_map_act_input =
is_single_tensor ? tensor_map_act_input_static : g_tensor_maps_act_input[tensor_id];
const CUtensorMap &tensor_map_output_rowwise =
is_single_tensor ? tensor_map_output_rowwise_static : g_tensor_maps_output_rowwise[tensor_id];
const CUtensorMap &tensor_map_output_colwise =
is_single_tensor ? tensor_map_output_colwise_static : g_tensor_maps_output_colwise[tensor_id];
const bool leading_thread = (threadIdx.x == 0);
if (leading_thread && (!is_single_tensor)) {
fence_acquire_tensormap(&tensor_map_input);
if constexpr (COMPUTE_ACTIVATIONS) {
fence_acquire_tensormap(&tensor_map_act_input);
}
if constexpr (ROWWISE_SCALING) {
fence_acquire_tensormap(&tensor_map_output_rowwise);
}
if constexpr (COLWISE_SCALING) {
fence_acquire_tensormap(&tensor_map_output_colwise);
}
}
const size_t blocks_X_num_in_current_tensor = DIVUP(cols, static_cast<size_t>(128));
const size_t block_id_in_current_tensor =
is_single_tensor ? blockIdx.x : (blockIdx.x - tensor_base / ELTS_PER_CHUNK);
const size_t block_id_Y = block_id_in_current_tensor / blocks_X_num_in_current_tensor;
const size_t block_id_X = block_id_in_current_tensor % blocks_X_num_in_current_tensor;
const size_t block_offset_Y = block_id_Y * CHUNK_DIM_Y;
const size_t block_offset_X = block_id_X * CHUNK_DIM_X;
e8m0_t *const scales_rowwise =
scales_rowwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_X);
e8m0_t *const scales_colwise =
scales_colwise_ptr + (is_single_tensor ? 0 : tensor_base / SCALE_DIM_Y);
const size_t scales_block_offset_Y_rowwise = block_id_Y * CHUNK_DIM_Y;
const size_t scales_block_offset_X_rowwise = block_id_X * CHUNK_DIM_X / SCALE_DIM_X;
const size_t scales_block_offset_Y_colwise = block_id_Y * CHUNK_DIM_Y / SCALE_DIM_Y;
const size_t scales_block_offset_X_colwise = block_id_X * CHUNK_DIM_X;
const size_t tid_Y_rowwise = threadIdx.x / THREADS_X;
const size_t tid_X_rowwise = threadIdx.x % THREADS_X;
const size_t tid_Y_colwise = 0;
const size_t tid_X_colwise = threadIdx.x;
const size_t thread_offset_Y_rowwise = tid_Y_rowwise;
const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X;
const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise;
const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise;
const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise;
const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise;
// Helps resolve bank conflicts in shmem
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT);
constexpr size_t elt_input_mem = buff_size_aligned_in;
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
constexpr size_t in_mem = elt_input_mem + act_input_mem;
constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
extern __shared__ unsigned char dynamic_shmem[];
unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem);
IType *in_sh = reinterpret_cast<IType *>(dshmem);
IType *act_in_sh = reinterpret_cast<IType *>(dshmem + elt_input_mem);
OType *out_rowwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem);
OType *out_colwise_data_sh = reinterpret_cast<OType *>(dshmem + in_mem + out_mem_rowwise);
IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer
constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
float partial_dbias_colwise = 0.0f;
float thread_dbias_rowwise[SCALE_DIM_X];
if constexpr (IS_DBIAS) {
#pragma unroll
for (int j = 0; j < SCALE_DIM_X; ++j) {
thread_dbias_rowwise[j] = 0.0f;
}
}
float block_amax = 0.0f;
// Initialize shared memory barriers with the number of threads participating in them.
#pragma nv_diag_suppress static_var_with_dynamic_init
__shared__ alignas(8) uint64_t mbar[STAGES];
initialize_barriers<STAGES, THREADS_PER_CHUNK>(mbar, leading_thread);
int parity = 0;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0],
&tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], leading_thread);
} else {
copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size,
&mbar[0], leading_thread);
}
#pragma unroll
for (int stage = 0; stage < STAGES; ++stage) {
const size_t buff = stage % BUFFS_NUM;
const size_t next_stage = stage + 1;
const size_t stage_offset_Y = stage * BUFF_DIM_Y;
if (next_stage < STAGES) {
// Wait for TMA transfer to have finished reading shared memory.
// I.e. the buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<1>();
const size_t next_buff = next_stage % BUFFS_NUM;
const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y;
const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y;
const size_t global_offset_X = block_offset_X;
const size_t next_buff_offset = next_buff * BUFF_DIM;
if constexpr (IS_DACT) {
copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input,
global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage],
leading_thread);
} else {
copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X,
global_offset_Y, shmem_buff_size, &mbar[next_stage], leading_thread);
}
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], parity);
float thread_amax = 0.0f;
if constexpr (COLWISE_SCALING) {
const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
thread_amax = 0.0f;
float in_compute_colwise[BUFF_DIM_Y];
IType in_colwise_IType[BUFF_DIM_Y];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
IType thread_amax_f16 = static_cast<IType>(0.0f);
#pragma unroll
for (int i = 0; i < BUFF_DIM_Y; ++i) {
const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
in_colwise_IType[i] = in_sh[shmem_offset_colwise];
thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i]));
}
thread_amax = static_cast<float>(thread_amax_f16);
} else {
#pragma unroll
for (int i = 0; i < BUFF_DIM_Y; ++i) {
const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X;
float elt = static_cast<float>(in_sh[shmem_offset_colwise]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
}
if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in_sh[shmem_offset_colwise]);
elt *= OP(act_in_elt, {});
}
if constexpr (IS_DBIAS) {
partial_dbias_colwise += elt;
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
// Cache computed activations to avoid computing them again in the 2nd pass along another dimension
if constexpr (IS_CACHED_ACT_OP) {
cached_act_sh[shmem_offset_colwise] = static_cast<IType>(elt);
}
thread_amax = fmaxf(thread_amax, fabsf(elt));
in_compute_colwise[i] = elt;
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const size_t global_scales_offset_X = scales_offset_X_colwise;
size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
DIVUP(rows, static_cast<size_t>(128)));
} else {
scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
}
scales_colwise[scale_idx] = biased_exponent;
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int i = 0; i < SCALE_DIM_Y; ++i) {
float in;
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
in = static_cast<float>(in_colwise_IType[i]);
} else {
in = in_compute_colwise[i];
}
const float scaled_out = in * block_scale_inverse;
const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X;
out_colwise_data_sh[shmem_offset_elt] = static_cast<OType>(scaled_out);
}
}
if constexpr (ROWWISE_SCALING) {
const size_t shmem_offset_base_rowwise =
buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X;
thread_amax = 0.0f;
float in_compute_rowwise[SCALE_DIM_X];
Vec<IType, PACK_SIZE> in_cached[WAVES];
// used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY
Vec<IType2, PACK_SIZE / 2> in_IType[WAVES];
// 1. Read/Compute elements. Find MXFP8-block AMAX
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
// Load elements
in_IType[w].load_from(&in_sh[shmem_offset_rowwise]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]);
}
}
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
} else if constexpr (IS_CACHED_ACT_OP) {
// ensures that all writes to cache made in the section above are visible to all threads
__syncthreads();
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
// Load cached elements
in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]);
// Since the TMA data-alignment requirement is 16B (i.e. cols % 8 == 0 for BF16 elements),
// a single check (w.r.t. the column direction) is sufficient to ensure the entire wave stays within bounds
if constexpr (std::is_same_v<IType, float>) {
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e]));
}
} else {
#pragma unroll
for (int e = 0; e < PACK_SIZE; e += 2) {
const IType2 in_cached_2x = {in_cached[w].data.elt[e], in_cached[w].data.elt[e + 1]};
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x);
}
}
}
if constexpr (!std::is_same_v<IType, float>) {
thread_amax =
static_cast<float>(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y)));
}
} else {
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx;
Vec<IType, PACK_SIZE> in;
Vec<IType, PACK_SIZE> act_in;
in.load_from(&in_sh[shmem_offset_rowwise]);
if constexpr (IS_DACT) {
act_in.load_from(&act_in_sh[shmem_offset_rowwise]);
}
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
// Compute element
float elt = static_cast<float>(in.data.elt[e]);
if constexpr (IS_ACT) {
elt = OP(elt, {});
}
if constexpr (IS_DACT) {
float act_in_elt = static_cast<float>(act_in.data.elt[e]);
elt *= OP(act_in_elt, {});
}
// If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again
if constexpr (IS_DBIAS && (!COLWISE_SCALING)) {
thread_dbias_rowwise[j] += elt;
}
// Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32
if constexpr (!std::is_same_v<IType, float>) {
elt = static_cast<float>(static_cast<IType>(elt));
}
thread_amax = fmaxf(thread_amax, fabsf(elt));
in_compute_rowwise[j] = elt;
}
}
}
// 2. Compute E8M0 scaling factor
const e8m0_t biased_exponent =
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X,
DIVUP(cols, static_cast<size_t>(128)));
} else {
scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
}
scales_rowwise[scale_idx] = biased_exponent;
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse};
// 3. Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
Vec<OType2, PACK_SIZE / 2> out;
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
IType2 in;
OType2 &out_pair = reinterpret_cast<OType2 &>(out.data.elt[e]);
if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v<IType, float>)) {
in = in_IType[w].data.elt[e];
} else if constexpr (IS_CACHED_ACT_OP) {
in.x = in_cached[w].data.elt[2 * e];
in.y = in_cached[w].data.elt[2 * e + 1];
} else {
const int j = w * PACK_SIZE + 2 * e;
in.x = in_compute_rowwise[j];
in.y = in_compute_rowwise[j + 1];
}
ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x);
}
const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise;
const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx;
out.store_to(&out_rowwise_data_sh[shmem_offset_rowwise]);
}
}
__builtin_assume(block_amax >= 0);
__builtin_assume(thread_amax >= 0);
block_amax = fmaxf(block_amax, thread_amax);
// Wait for shared memory writes to be visible to TMA engine.
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine.
// Initiate TMA transfer to copy shared memory to global memory
if (leading_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X;
const int buff_offset = buff * BUFF_DIM;
if constexpr (ROWWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_rowwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_rowwise_data_sh[buff_offset]));
}
if constexpr (COLWISE_SCALING) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_colwise), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&out_colwise_data_sh[buff_offset]));
}
// Create a "bulk async-group" out of the previous bulk copy operation.
ptx::cp_async_bulk_commit_group();
}
}
parity ^= 1;
if constexpr (IS_DBIAS) {
if (is_single_tensor) {
float thread_partial_dbias = 0.0f;
if constexpr (COLWISE_SCALING) {
thread_partial_dbias = partial_dbias_colwise;
} else {
// Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH]
// HEIGHT = THREADS_Y
// WIDTH = THREADS_X * (SCALE_DIM_X + 1)
// Added extra 1-element padding per thread_X to reduce bank conflicts
float *partial_dbias_rowwise = reinterpret_cast<float *>(dshmem);
constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1);
const int shmem_thread_offset =
tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1);
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx;
#pragma unroll
for (int e = 0; e < PACK_SIZE; ++e) {
const int j = w * PACK_SIZE + e;
const int shmem_elt_idx = swizzled_group_offset + e;
partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j];
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < THREADS_Y; ++i) {
// Add extra element offset per MXFP8 scaling block [1x32]
const int scaling_block = threadIdx.x / SCALE_DIM_X;
thread_partial_dbias +=
partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block];
}
}
const int dbias_stride = cols;
const int dbias_offset_Y = block_id_Y;
const int dbias_offset_X = block_id_X * CHUNK_DIM_X + threadIdx.x;
const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X;
const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols);
if (!col_out_of_bounds_dbias) {
dbias_workspace[dbias_idx] = thread_partial_dbias;
}
}
}
if (amax_ptr != nullptr) {
const int warp_id = threadIdx.x / THREADS_PER_WARP;
// Reduce the amax over the block
block_amax = reduce_max<THREADS_PER_CHUNK / THREADS_PER_WARP>(block_amax, warp_id);
}
if (leading_thread && amax_ptr != nullptr) {
atomicMaxFloat(amax_ptr, block_amax);
}
destroy_barriers<STAGES>(mbar, leading_thread);
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace group_quantize_kernel
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &)>
void group_quantize(const GroupedTensor *input, const GroupedTensor *activations,
const Tensor *noop, GroupedTensor *output, Tensor *dbias, Tensor *workspace,
cudaStream_t stream) {
using namespace group_quantize_kernel;
checkCuDriverContext(stream);
CheckNoopTensor(*noop, "cast_noop");
const bool use_rowwise_scaling = output->has_data();
const bool use_colwise_scaling = output->has_columnwise_data();
NVTE_CHECK(use_rowwise_scaling || use_colwise_scaling,
"Either rowwise or columnwise output data need to be allocated.");
ScalingType scaling_type = ScalingType::BIDIMENSIONAL;
if (!use_colwise_scaling) {
scaling_type = ScalingType::ROWWISE;
} else if (!use_rowwise_scaling) {
scaling_type = ScalingType::COLWISE;
}
ShapeRepresentation shape_rep = ShapeRepresentation::SAME_BOTH_DIMS;
if (output->all_same_shape()) {
shape_rep = ShapeRepresentation::SAME_BOTH_DIMS;
} else if (output->all_same_first_dim()) {
shape_rep = ShapeRepresentation::VARYING_LAST_DIM;
} else if (output->all_same_last_dim()) {
shape_rep = ShapeRepresentation::VARYING_FIRST_DIM;
} else if (output->varying_both_dims()) {
shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS;
}
// Treat a grouped tensor with const last dims as a single tensor
const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS || shape_rep == VARYING_FIRST_DIM);
NVTE_CHECK(input->num_tensors == output->num_tensors,
"Number of input and output tensors must be same.");
NVTE_CHECK(input->has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type.");
if (IS_DACT) {
NVTE_CHECK(activations->has_data(), "Activations tensor must have data.");
NVTE_CHECK(input->num_tensors == activations->num_tensors,
"Number of grad and activations tensors must be same.");
NVTE_CHECK(input->dtype() == activations->dtype(),
"Grad and activations tensors must have the same type.");
}
const size_t first_logical_dim = input->logical_shape.data[0];
const size_t last_logical_dim = input->logical_shape.data[1];
const size_t elts_total = first_logical_dim * last_logical_dim;
const size_t num_tensors = input->num_tensors;
size_t blocks = 0;
if (is_single_tensor) {
const size_t blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y);
const size_t blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X);
blocks = blocks_Y * blocks_X;
} else {
NVTE_CHECK(num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS,
"Number of tensors in a group is larger than "
"the MAX number of supported descriptors (64).");
// Only full tiles supported
NVTE_CHECK(last_logical_dim % CHUNK_DIM_X == 0,
"Last dimension of a grouped tensor should be divisible by 128.");
blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X);
}
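// Worked example (illustrative): for a group with logical shape [768, 512], the single-tensor path
// launches DIVUP(768, 128) * DIVUP(512, 128) = 6 * 4 = 24 blocks of 128x128 elements; the grouped
// path covers the same element count with DIVUP(768 * 512, 128 * 128) = 24 blocks.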
const dim3 grid(blocks);
const size_t block_size = THREADS_PER_CHUNK;
const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales;
// Logical shape of a grouped tensor with all dims varying is [1, M*K]
if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) {
NVTE_CHECK(first_logical_dim % 128 == 0,
"First dimension of a grouped tensor should be divisible by 128.");
}
const int64_t *const offsets_ptr = reinterpret_cast<const int64_t *>(input->tensor_offsets.dptr);
const int64_t *const first_dims_ptr = reinterpret_cast<const int64_t *>(input->first_dims.dptr);
const int64_t *const last_dims_ptr = reinterpret_cast<const int64_t *>(input->last_dims.dptr);
float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
e8m0_t *const scales_rowwise_ptr = reinterpret_cast<e8m0_t *>(output->scale_inv.dptr);
e8m0_t *const scales_colwise_ptr = reinterpret_cast<e8m0_t *>(output->columnwise_scale_inv.dptr);
if (use_rowwise_scaling) {
NVTE_CHECK(scales_rowwise_ptr != nullptr, "Scaling tensor must be allocated");
}
if (use_colwise_scaling) {
NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated");
}
const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y);
const size_t dbias_cols = last_logical_dim;
if constexpr (IS_DBIAS) {
NVTE_CHECK(is_single_tensor,
"DBias is only supported for tensors with the const last dimension.");
NVTE_CHECK(dbias->data.dtype == input->dtype(),
"DBias must have the same type as input_tensor.");
NVTE_CHECK(dbias->data.shape == std::vector<size_t>{last_logical_dim}, "Wrong shape of DBias.");
NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor.");
if (workspace->data.dptr == nullptr) {
workspace->data.shape = {dbias_rows, dbias_cols};
workspace->data.dtype = DType::kFloat32;
return;
}
}
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
input->dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};
constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
input_type_bit_size);
if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
input_type_bit_size);
}
if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
output_type_bit_size);
}
if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data,
first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X,
last_logical_dim, 0, output_type_bit_size);
}
constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t elt_input_mem = buff_size_aligned_in;
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
constexpr size_t in_mem = elt_input_mem + act_input_mem;
const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
const size_t out_mem = out_rowwise_mem + out_colwise_mem;
const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
auto kernel =
group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
true, true, WITH_GEMM_SWIZZLED_SCALES>;
switch (scaling_type) {
case ScalingType::ROWWISE: {
kernel =
group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, false, WITH_GEMM_SWIZZLED_SCALES>;
break;
}
case ScalingType::COLWISE: {
kernel =
group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, false, true, WITH_GEMM_SWIZZLED_SCALES>;
break;
}
case ScalingType::BIDIMENSIONAL: {
kernel =
group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, true, WITH_GEMM_SWIZZLED_SCALES>;
break;
}
}
// Update tensor descriptors before launching the kernel
if (!is_single_tensor) {
const IType *const input_dptr = reinterpret_cast<const IType *>(input->data.dptr);
const IType *const act_input_dptr =
IS_DACT ? reinterpret_cast<const IType *>(activations->data.dptr) : nullptr;
OType *const output_rowwise_dptr =
use_rowwise_scaling ? reinterpret_cast<OType *>(output->data.dptr) : nullptr;
OType *const output_colwise_dptr =
use_colwise_scaling ? reinterpret_cast<OType *>(output->columnwise_data.dptr)
: nullptr;
update_tma_descriptors<IType, OType><<<num_tensors, 32, 0, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr,
output_colwise_dptr, shape_rep, num_tensors, first_logical_dim,
last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr,
use_rowwise_scaling, use_colwise_scaling, IS_DACT);
}
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim,
last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr,
scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr);
if constexpr (IS_DBIAS) {
common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}
NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
}
} // namespace mxfp8
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GROUP_QUANTIZE_MXFP8_CUH_
......@@ -25,6 +25,7 @@
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "core_nvfp4.cuh"
#include "specialized/quantize_transpose_nvfp4_tuned_1D.cuh"
namespace transformer_engine {
namespace dispatch {
......@@ -1163,6 +1164,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
#if FP4_TYPE_SUPPORTED
using namespace quantize_transpose_kernel;
using namespace ptx;
bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
// If transposed output is allocated, return the transposed data. Otherwise, it's not necessary to
......@@ -1170,6 +1172,11 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output,
// TODO(Frank): Is there a better way to do this?
bool return_transpose = output->has_columnwise_data();
if (!use_2d_quantization && (input.dtype() == DType::kBFloat16)) {
quantize_transpose_tuned_1D(input, noop, output, quant_config, stream);
return;
}
constexpr bool COMPUTE_ACTIVATIONS = false;
using ParamOP = Empty;
constexpr float (*OP)(float, const ParamOP &) = nullptr;
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
/*! \file quantize_transpose_nvfp4_tuned_1D.cuh
* \brief Tuned kernel to cast to NVFP4 and transpose.
*/
#ifndef TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_
#define TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_runtime.h>
#include <transformer_engine/transformer_engine.h>
#include "../../../common.h"
#include "../../../util/math.h"
#include "../../../util/ptx.cuh"
#include "../../../utils.cuh"
#include "../core_nvfp4.cuh"
namespace transformer_engine {
namespace dispatch {
namespace nvfp4 {
namespace quantize_transpose_tuned_kernel {
using namespace quantization_and_transposition_SF;
using namespace core;
using namespace ptx;
#if FP4_TYPE_SUPPORTED
struct TunableConfig {
static constexpr int CHUNK_DIM_Y = 128;
static constexpr int CHUNK_DIM_X = 128;
static constexpr int PREFETCH_STAGES = 1;
static constexpr bool PERSISTENT = false;
};
constexpr int SCALE_DIM = 16; // NVFP4 block (x16 elts)
constexpr int THREADS_NUM = 128;
constexpr int ELTS_PER_THREAD = 16;
constexpr int TILE_DIM_Y = 64;
constexpr int TILE_DIM_X = 64;
static_assert(ELTS_PER_THREAD == SCALE_DIM && "Hardcoded and fixed parameter\0");
static_assert((THREADS_NUM * ELTS_PER_THREAD <= TILE_DIM_Y * TILE_DIM_X) &&
"Unbalanced threads workload\0");
static_assert((TunableConfig::CHUNK_DIM_Y % TILE_DIM_Y == 0) &&
"Chunk size Y must be evenly divisible by the tile size Y\0");
static_assert((TunableConfig::CHUNK_DIM_X % TILE_DIM_X == 0) &&
"Chunk size X must be evenly divisible by the tile size X\0");
static_assert((TILE_DIM_Y % SCALE_DIM == 0) &&
"Tile size Y must be evenly divisible by the scale dim\0");
static_assert((TILE_DIM_X % SCALE_DIM == 0) &&
"Tile size X must be evenly divisible by the scale dim\0");
constexpr int TILES_Y = TunableConfig::CHUNK_DIM_Y / TILE_DIM_Y;
constexpr int TILES_X = TunableConfig::CHUNK_DIM_X / TILE_DIM_X;
constexpr int THREADS_PER_SCALE_ROWWISE = SCALE_DIM / ELTS_PER_THREAD;
constexpr int SCALES_PER_CHUNK_Y = TunableConfig::CHUNK_DIM_Y / SCALE_DIM;
constexpr int SCALES_PER_CHUNK_X = TunableConfig::CHUNK_DIM_X / SCALE_DIM;
constexpr int SCALES_PER_TILE_Y = TILE_DIM_Y / SCALE_DIM;
constexpr int SCALES_PER_TILE_X = TILE_DIM_X / SCALE_DIM;
constexpr int STAGES_Y = TILES_Y;
constexpr int STAGES_X = TILES_X;
constexpr int STAGES = STAGES_Y * STAGES_X;
constexpr int BUFFS_NUM = TunableConfig::PREFETCH_STAGES + 1;
constexpr int BUFFS_NUM_IN = BUFFS_NUM;
constexpr int BUFFS_NUM_OUT = BUFFS_NUM;
constexpr int BUFFS_NUM_OUT_TR = 2;
constexpr int BUFF_DIM_Y = TILE_DIM_Y;
constexpr int BUFF_DIM_X = TILE_DIM_X;
constexpr int BUFF_SIZE = BUFF_DIM_Y * BUFF_DIM_X;
constexpr int BUFF_SIZE_TOTAL = BUFF_SIZE * BUFFS_NUM;
// Input buffer (BF16)
constexpr int BUFF_IN_DIM_Y = BUFF_DIM_Y;
constexpr int BUFF_IN_DIM_X = BUFF_DIM_X;
constexpr int BUFF_IN_SIZE = BUFF_IN_DIM_Y * BUFF_IN_DIM_X;
constexpr int BUFF_IN_ELTS_NUM = BUFF_IN_DIM_Y * BUFF_IN_DIM_X;
// Output buffer (NVFP4)
constexpr int BUFF_OUT_DIM_Y = BUFF_DIM_Y;
constexpr int BUFF_OUT_DIM_X = (BUFF_DIM_X * 4) / 8;
constexpr int BUFF_OUT_SIZE = BUFF_OUT_DIM_Y * BUFF_OUT_DIM_X;
// Output transpose buffer (NVFP4)
constexpr int BUFF_OUT_TR_DIM_Y = BUFF_DIM_X;
constexpr int BUFF_OUT_TR_DIM_X = (BUFF_DIM_Y * 4) / 8;
constexpr int BUFF_OUT_TR_SIZE = BUFF_OUT_TR_DIM_Y * BUFF_OUT_TR_DIM_X;
// Manual swizzling parameters to reduce SHMEM bank conflicts
constexpr int PACK_SIZE = 8;
constexpr int WAVES = ELTS_PER_THREAD / PACK_SIZE;
constexpr int THREADS_X_ROWWISE = TILE_DIM_X / ELTS_PER_THREAD;
constexpr int THREADS_Y_ROWWISE = THREADS_NUM / THREADS_X_ROWWISE;
constexpr int THREADS_X_TR = TILE_DIM_X / 2;
constexpr int THREADS_Y_TR = THREADS_NUM / THREADS_X_TR;
constexpr int ITERATIONS_NORMAL = BUFF_DIM_Y / THREADS_Y_ROWWISE;
constexpr int ITERATIONS_TR = SCALES_PER_TILE_Y / THREADS_Y_TR;
static_assert(ITERATIONS_TR >= 1 && "Number of transpose iterations should be >=1\0");
static_assert((SCALES_PER_TILE_Y % THREADS_Y_TR == 0) &&
"Partial transpose iterations are not supported\0");
constexpr int BUFF_OUT_IT_OFFSET = BUFF_OUT_TR_DIM_X / ITERATIONS_TR / STAGES;
static_assert(BUFF_DIM_Y >= SCALE_DIM &&
"Number of buffer rows must be greater or equal to the size of the columwise "
"scaling block\0");
static_assert(TunableConfig::CHUNK_DIM_Y >= BUFF_DIM_Y);
static_assert(BUFF_DIM_Y >= THREADS_Y_ROWWISE &&
"Number of buffer rows must be greater or equal to the number of rowwise "
"processing threads in Y dimension\0");
// Number of 4-bit elements that span 32 banks (4-byte each) of shared memory
constexpr int TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256
// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory
constexpr int THREADS_PER_BANK = TOTAL_BANKS_WIDTH / ELTS_PER_THREAD;
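// With the values above: THREADS_X_ROWWISE = 64 / 16 = 4, THREADS_Y_ROWWISE = 128 / 4 = 32,
// ITERATIONS_NORMAL = 64 / 32 = 2, THREADS_X_TR = 32, THREADS_Y_TR = 4, and
// ITERATIONS_TR = SCALES_PER_TILE_Y / THREADS_Y_TR = 4 / 4 = 1.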
using IType = bf16;
using IType2 = typename ptx::FPx2<IType>;
using IType3D = IType[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X];
using IType2x3D = IType2[BUFFS_NUM_IN][BUFF_IN_DIM_Y][BUFF_IN_DIM_X / 2];
using OType2x3D = fp4e2m1x2[BUFFS_NUM_OUT][BUFF_OUT_DIM_Y][BUFF_OUT_DIM_X];
using OType2xt3D = fp4e2m1x2[BUFFS_NUM_OUT_TR][BUFF_OUT_TR_DIM_Y][BUFF_OUT_TR_DIM_X];
using ScalesType2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_Y][SCALES_PER_CHUNK_X];
using ScalesTypeTr2D = nvfp4_scale_t[TunableConfig::CHUNK_DIM_X][SCALES_PER_CHUNK_Y];
using RNG_t = typename transformer_engine::curanddx::detail::philox4x32_native_state<10>;
template <bool USE_FAST_MATH>
struct SCALING_COEFFICIENT_TYPE {};
template <>
struct SCALING_COEFFICIENT_TYPE<false> {
using type = float;
};
template <>
struct SCALING_COEFFICIENT_TYPE<true> {
using type = bf16;
};
__device__ __forceinline__ float get_amax_of_pair(const IType2 pair) {
return static_cast<float>(__hmax(__habs(pair.x), __habs(pair.y)));
}
// Compute "correct" per-block encoding scaling factor
template <typename SF_TYPE>
__device__ __forceinline__ SF_TYPE
compute_nvfp4_scaling_coefficient(const nvfp4_scale_t S_dec_block, const float S_enc) {
NVTE_DEVICE_ERROR("Unsupported scaling-factor type. Only FP32 and BF16 are supported.");
}
template <>
__device__ __forceinline__ float compute_nvfp4_scaling_coefficient<float>(
const nvfp4_scale_t S_dec_block, const float S_enc) {
const float S_dec = 1.0f / S_enc;
const float scale_rcp =
fminf(1.0f / (static_cast<float>(S_dec_block) * S_dec), detail::TypeExtrema<float>::max);
return scale_rcp;
}
template <>
__device__ __forceinline__ bf16
compute_nvfp4_scaling_coefficient<bf16>(const nvfp4_scale_t S_dec_block, const float S_enc) {
const float scale_rcp =
fminf(S_enc / (static_cast<float>(S_dec_block)), detail::TypeExtrema<bf16>::max);
return static_cast<bf16>(scale_rcp);
}
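// Both specializations compute the same coefficient, min(S_enc / S_dec_block, max of the output type),
// differing only in output precision: FP32 for the exact path, BF16 when USE_FAST_MATH is enabled.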
template <bool USE_STOCHASTIC_ROUNDING, bool USE_FAST_MATH>
__device__ __forceinline__ void colwise_scaling(const IType *__restrict__ sIn_ptr,
fp4e2m1x2 *__restrict__ sOut_tr_ptr,
nvfp4_scale_t *__restrict__ sSFcolwise_ptr,
const float S_enc_colwise, const int stage_Y,
const int stage_X, const int buff_in,
const int buff_out_tr, RNG_t &rng,
uint4 &random_uint4, int &rnd_idx) {
using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE<USE_FAST_MATH>::type;
const auto &sIn2x = *reinterpret_cast<const IType2x3D *>(sIn_ptr);
auto &sOut_tr = *reinterpret_cast<OType2xt3D *>(sOut_tr_ptr);
auto &sSFcolwise = *reinterpret_cast<ScalesTypeTr2D *>(sSFcolwise_ptr);
const int warp = threadIdx.x / THREADS_PER_WARP;
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int tid_Y_colwise = (thread_lane % 4 + warp) % 4;
const int tid_X_colwise = thread_lane;
const int thread_offset_Y_colwise = tid_Y_colwise * SCALE_DIM;
const int thread_offset_X_colwise = tid_X_colwise * 2;
const int in_thread_offset_Y = thread_offset_Y_colwise;
const int in_thread_offset_X = thread_offset_X_colwise / 2;
const int out_tr_thread_offset_Y = thread_offset_X_colwise;
const int out_tr_thread_offset_X = thread_offset_Y_colwise / 2;
const int scale_tr_offset_Y = (stage_X * TILE_DIM_X) + 2 * tid_X_colwise;
const int scale_tr_offset_X = (stage_Y * SCALES_PER_TILE_Y) + tid_Y_colwise;
__align__(8) IType rIn[2][SCALE_DIM];
// Read (cache) a pair of input elements (S2R). Find NVFP4-block AMAX
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int i = 0; i < SCALE_DIM; ++i) {
const IType2 elt_pair =
ptx::ld_shared_b32(&sIn2x[buff_in][in_thread_offset_Y + i][in_thread_offset_X]);
rIn[0][i] = elt_pair.x;
rIn[1][i] = elt_pair.y;
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, elt_pair);
}
const float block_amax[2] = {static_cast<float>(__habs(thread_amax_2x.x)),
static_cast<float>(__habs(thread_amax_2x.y))};
#pragma unroll
for (int w = 0; w < 2; ++w) {
const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax[w], S_enc_colwise);
// Store scaling factors to SMEM buffer (R2S)
sSFcolwise[scale_tr_offset_Y + w][scale_tr_offset_X] = S_dec_b_fp8;
const scaling_coeff_type SFcoefficient =
compute_nvfp4_scaling_coefficient<scaling_coeff_type>(S_dec_b_fp8, S_enc_colwise);
// Scale elements
__align__(8) uint32_t rOut[SCALE_DIM / 8];
#pragma unroll
for (int e = 0; e < SCALE_DIM / 8; ++e) {
const uint64_t elts03 = *reinterpret_cast<uint64_t *>(&rIn[w][8 * e]);
const uint64_t elts47 = *reinterpret_cast<uint64_t *>(&rIn[w][8 * e + 4]);
if constexpr (USE_STOCHASTIC_ROUNDING) {
const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx);
const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx);
rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding<scaling_coeff_type>(
elts03, elts47, SFcoefficient, rbits03, rbits47);
} else {
rOut[e] = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest<scaling_coeff_type>(elts03, elts47,
SFcoefficient);
}
}
uint64_t &out_pack_16x = *reinterpret_cast<uint64_t *>(rOut);
ptx::st_shared_b64(&sOut_tr[buff_out_tr][out_tr_thread_offset_Y + w][out_tr_thread_offset_X],
out_pack_16x);
}
}
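// Summary: each thread handles a 16x2 (SCALE_DIM x 2) patch of BF16 inputs, computes one decoding
// scaling factor per 16-element column, and writes each scaled column as a packed row of 16 NVFP4
// values into the transposed output buffer.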
template <bool USE_STOCHASTIC_ROUNDING, bool USE_FAST_MATH>
__device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_ptr,
fp4e2m1x2 *__restrict__ sOut_ptr,
nvfp4_scale_t *__restrict__ sSFrowwise_ptr,
const float S_enc_rowwise, const int stage_Y,
const int stage_X, const int buff_in,
const int buff_out, RNG_t &rng, uint4 &random_uint4,
int &rnd_idx) {
using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE<USE_FAST_MATH>::type;
const auto &sIn = *reinterpret_cast<const IType3D *>(sIn_ptr);
auto &sOut = *reinterpret_cast<OType2x3D *>(sOut_ptr);
auto &sSFrowwise = *reinterpret_cast<ScalesType2D *>(sSFrowwise_ptr);
const int thread_lane = threadIdx.x % THREADS_PER_WARP;
const int bank_group = thread_lane / THREADS_PER_BANK;
const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE;
const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE;
const int thread_offset_Y_rowwise = tid_Y_rowwise;
const int thread_offset_X_rowwise = tid_X_rowwise * ELTS_PER_THREAD;
const int SF_thread_offset_rowwise_Y = tid_Y_rowwise;
const int SF_thread_offset_rowwise_X = tid_X_rowwise / THREADS_PER_SCALE_ROWWISE;
const bool SF_storing_thread = (tid_X_rowwise % THREADS_PER_SCALE_ROWWISE == 0);
const int stage_rowwise_scales_offset_Y = SF_thread_offset_rowwise_Y + stage_Y * TILE_DIM_Y;
const int stage_rowwise_scales_offset_X =
SF_thread_offset_rowwise_X + stage_X * SCALES_PER_TILE_X;
#pragma unroll
for (int it = 0; it < ITERATIONS_NORMAL; ++it) {
const int it_offset_Y_rowwise = thread_offset_Y_rowwise + it * THREADS_Y_ROWWISE;
__align__(16) IType2 rIn[WAVES][PACK_SIZE / 2];
// Read (cache) input elements (S2R). Find NVFP4-block AMAX
IType2 thread_amax_2x = {static_cast<IType>(0.0f), static_cast<IType>(0.0f)};
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD;
const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx;
// Load elements
__uint128_t &elts_8x = *reinterpret_cast<__uint128_t *>(&rIn[w]);
elts_8x = ptx::ld_shared_b128(&sIn[buff_in][it_offset_Y_rowwise][swizzled_thread_idx]);
#pragma unroll
for (int e = 0; e < PACK_SIZE / 2; ++e) {
ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, rIn[w][e]);
}
}
const float block_amax = get_amax_of_pair(thread_amax_2x);
const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise);
const scaling_coeff_type SFcoefficient =
compute_nvfp4_scaling_coefficient<scaling_coeff_type>(S_dec_b_fp8, S_enc_rowwise);
// Store scaling factors to SMEM buffer (R2S)
if (SF_storing_thread) {
const int scales_offset_Y = stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE;
const int scales_offset_X = stage_rowwise_scales_offset_X;
sSFrowwise[scales_offset_Y][scales_offset_X] = S_dec_b_fp8;
}
// Scale elements
#pragma unroll
for (int w = 0; w < WAVES; ++w) {
const uint64_t elts03 = *reinterpret_cast<uint64_t *>(&rIn[w][0]);
const uint64_t elts47 = *reinterpret_cast<uint64_t *>(&rIn[w][2]);
uint32_t out_x8;
if constexpr (USE_STOCHASTIC_ROUNDING) {
const uint32_t rbits03 = core::get_rbits(rng, random_uint4, rnd_idx);
const uint32_t rbits47 = core::get_rbits(rng, random_uint4, rnd_idx);
out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_stochastic_rounding<scaling_coeff_type>(
elts03, elts47, SFcoefficient, rbits03, rbits47);
} else {
out_x8 = ptx::mul_cvt_bf16_to_fp4_8x_round_to_nearest<scaling_coeff_type>(elts03, elts47,
SFcoefficient);
}
const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % ELTS_PER_THREAD;
const int swizzled_idx = (swizzled_group_idx + thread_offset_X_rowwise) / 2;
ptx::st_shared_b32(&sOut[buff_out][it_offset_Y_rowwise][swizzled_idx], out_x8);
}
}
}
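// Summary: in each of the ITERATIONS_NORMAL iterations a thread loads 16 contiguous BF16 elements
// (two 128-bit waves of 8), derives one decoding scaling factor for that 1x16 block, and stores
// 16 packed NVFP4 values (two 32-bit packs) into the rowwise output buffer.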
template <bool USE_STOCHASTIC_ROUNDING, bool USE_FAST_MATH, bool RETURN_TRANSPOSE>
__global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel(
const __grid_constant__ CUtensorMap tensor_map_input,
const __grid_constant__ CUtensorMap tensor_map_output,
const __grid_constant__ CUtensorMap tensor_map_output_t, nvfp4_scale_t *const scales_ptr,
nvfp4_scale_t *const scales_t_ptr, const float *noop, const float *const amax_rowwise_ptr,
const float *const amax_colwise_ptr, const size_t rows, const size_t cols,
const size_t scale_stride, const size_t scale_stride_t, const size_t *rng_state) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
if (noop != nullptr && noop[0] == 1.0f) {
return;
}
const size_t rng_sequence =
threadIdx.x + blockIdx.x * THREADS_NUM + blockIdx.y * gridDim.x * THREADS_NUM;
const size_t rng_seed = rng_state != nullptr ? rng_state[0] : 0;
const size_t rng_offset = rng_state != nullptr ? rng_state[1] : 0;
RNG_t rng;
rng.init(rng_seed, rng_sequence, rng_offset);
uint4 random_uint4 = USE_STOCHASTIC_ROUNDING ? rng.generate4() : uint4{0, 0, 0, 0};
// Index into the 4x32-bit random values. It is incremented on each use and wraps back to 0 once
// all four 32-bit words of the uint4 have been consumed
int rnd_idx = 0;
const bool leading_thread = (threadIdx.x == 0);
constexpr int buff_elems = BUFF_DIM_Y * BUFF_IN_DIM_X;
constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems;
constexpr int buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr int buff_size_aligned_out =
DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT);
constexpr int buff_size_aligned_out_t =
DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT);
constexpr int in_mem = buff_size_aligned_in;
constexpr int out_mem_rowwise_data = buff_size_aligned_out;
constexpr int out_mem_colwise_data = RETURN_TRANSPOSE ? buff_size_aligned_out_t : 0;
constexpr int out_mem_rowwise_scales = DIVUP_TO_MULTIPLE(
TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
extern __shared__ unsigned char dynamic_shmem[];
unsigned char *dshmem = common::align_smem_ptr_per_TMA_requirements(dynamic_shmem);
IType *sIn_ptr = reinterpret_cast<IType *>(dshmem);
fp4e2m1x2 *sOut_ptr = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem);
fp4e2m1x2 *sOut_tr_ptr = reinterpret_cast<fp4e2m1x2 *>(dshmem + in_mem + out_mem_rowwise_data);
auto &sIn = *reinterpret_cast<IType3D *>(sIn_ptr);
auto &sOut = *reinterpret_cast<OType2x3D *>(sOut_ptr);
auto &sOut_tr = *reinterpret_cast<OType2xt3D *>(sOut_tr_ptr);
nvfp4_scale_t *sSFrowwise_ptr = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data);
nvfp4_scale_t *sSFcolwise_ptr = reinterpret_cast<nvfp4_scale_t *>(
dshmem + in_mem + out_mem_rowwise_data + out_mem_colwise_data + out_mem_rowwise_scales);
auto &sSFrowwise = *reinterpret_cast<ScalesType2D *>(sSFrowwise_ptr);
auto &sSFcolwise = *reinterpret_cast<ScalesTypeTr2D *>(sSFcolwise_ptr);
constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM;
// Compute a global encoding/decoding scaling factors for all S_dec_b
const float S_enc_rowwise =
(amax_rowwise_ptr == nullptr)
? 1.0f
: core::compute_global_encode_scaling_factor_FP4(*amax_rowwise_ptr);
const float S_enc_colwise =
(amax_colwise_ptr == nullptr)
? S_enc_rowwise
: core::compute_global_encode_scaling_factor_FP4(*amax_colwise_ptr);
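// Assumed behavior of compute_global_encode_scaling_factor_FP4: roughly
// S_enc = (FP4_MAX * FP8_MAX) / amax = (6.0 * 448.0) / amax, with clamping for zero or
// non-finite amax; see the helper for the exact definition.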
__shared__ uint64_t workID_mbar;
__shared__ __uint128_t workID_response;
constexpr uint32_t workID_response_size = sizeof(workID_response);
static_assert(workID_response_size == 16);
__shared__ uint64_t IN_buff_readable_mbar[BUFFS_NUM];
// Coordinates of the first chunk (CTA) to process
int32_t ctaid_X = blockIdx.x;
int32_t ctaid_Y = blockIdx.y;
// Initialize shared memory barriers with the number of threads participating in them
if (leading_thread) {
#pragma unroll
for (int buff = 0; buff < BUFFS_NUM; ++buff) {
ptx::mbarrier_init(&IN_buff_readable_mbar[buff], 1);
}
ptx::mbarrier_init(&workID_mbar, 1);
ptx::fence_proxy_async_shared_cta();
}
__syncthreads();
bool job_finished = false;
int buff_in = 0;
int buff_out = 0;
int buff_out_tr = 0;
int IN_buff_readable_parity[BUFFS_NUM] = {0, 0};
int ctaid_parity = 0;
// Prefetch input data only when processing the first chunk;
// this keeps a one-iteration load/compute overlap running for the entire lifetime of the kernel
#pragma unroll
for (int stage = 0; stage < TunableConfig::PREFETCH_STAGES; ++stage) {
const int buff_in = stage;
const int stage_Y = stage / STAGES_X;
const int stage_X = stage % STAGES_X;
const int stage_offset_Y = stage_Y * TILE_DIM_Y;
const int stage_offset_X = stage_X * TILE_DIM_X;
const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y;
const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X;
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X + stage_offset_X;
uint64_t *barrier = &IN_buff_readable_mbar[buff_in];
if (leading_thread) {
uint64_t *dst = reinterpret_cast<uint64_t *>(&sIn[buff_in]);
const uint64_t *src = reinterpret_cast<const uint64_t *>(&tensor_map_input);
// Arrive on the barrier and tell how many bytes are expected to come in
ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size);
// Initiate bulk tensor copy
ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y,
barrier);
}
}
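// At this point TunableConfig::PREFETCH_STAGES TMA loads are in flight. The main loop below
// always consumes buffer `buff_in` while the load into buffer
// (buff_in + PREFETCH_STAGES) % BUFFS_NUM has already been issued, so global->shared transfers
// overlap with quantization of the previous stage.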
while (!job_finished) {
const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y;
const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X;
const int block_offset_Y_tr = ctaid_X * TunableConfig::CHUNK_DIM_X;
const int block_offset_X_tr = ctaid_Y * TunableConfig::CHUNK_DIM_Y;
const int chunk_rows = rows - block_offset_Y;
const int chunk_cols = cols - block_offset_X;
const int scales_block_offset_Y_rowwise = ctaid_Y * TunableConfig::CHUNK_DIM_Y;
const int scales_block_offset_X_rowwise = ctaid_X * SCALES_PER_CHUNK_X;
const int scales_block_offset_Y_tr = ctaid_X * TunableConfig::CHUNK_DIM_X;
const int scales_block_offset_X_tr = ctaid_Y * SCALES_PER_CHUNK_Y;
if constexpr (TunableConfig::PERSISTENT) {
if (leading_thread) {
ptx::mbarrier_arrive_expect_tx_cta_relaxed_shared_cta(&workID_mbar, workID_response_size);
ptx::try_cancel_cta(&workID_mbar, &workID_response);
}
}
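// Persistent scheduling sketch (based on the try_cancel_cta / get_cancelled_cta_id_2D usage):
// rather than exiting after one chunk, the CTA asks cluster-launch control for the coordinates
// of a not-yet-launched CTA and processes that chunk next; a (-1, -1) response means no work is
// left and the loop exits via job_finished.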
#pragma unroll
for (int stage = 0; stage < STAGES; ++stage) {
const int stage_Y = stage / STAGES_X;
const int stage_X = stage % STAGES_X;
const int stage_offset_Y = stage_Y * TILE_DIM_Y;
const int stage_offset_X = stage_X * TILE_DIM_X;
if (stage == STAGES - TunableConfig::PREFETCH_STAGES) {
if constexpr (TunableConfig::PERSISTENT) {
ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&workID_mbar, ctaid_parity);
ptx::get_cancelled_cta_id_2D(&workID_response, ctaid_X, ctaid_Y);
ctaid_parity ^= 1;
} else {
ctaid_X = -1;
ctaid_Y = -1;
}
if (ctaid_X == -1 && ctaid_Y == -1) {
job_finished = true;
}
}
// Prefetch next stage Input data
if (!job_finished || (stage < STAGES - TunableConfig::PREFETCH_STAGES)) {
const int next_prefetch_buff = (buff_in + TunableConfig::PREFETCH_STAGES) % BUFFS_NUM;
const int next_prefetch_stage = (stage + TunableConfig::PREFETCH_STAGES) % STAGES;
const int next_prefetch_stage_Y = next_prefetch_stage / STAGES_X;
const int next_prefetch_stage_X = next_prefetch_stage % STAGES_X;
const int next_prefetch_stage_offset_Y = next_prefetch_stage_Y * TILE_DIM_Y;
const int next_prefetch_stage_offset_X = next_prefetch_stage_X * TILE_DIM_X;
// Offsets change because the coordinates of the next to-be-prefetched CTA also change
const int block_offset_Y = ctaid_Y * TunableConfig::CHUNK_DIM_Y;
const int block_offset_X = ctaid_X * TunableConfig::CHUNK_DIM_X;
const int global_offset_Y = block_offset_Y + next_prefetch_stage_offset_Y;
const int global_offset_X = block_offset_X + next_prefetch_stage_offset_X;
uint64_t *barrier = &IN_buff_readable_mbar[next_prefetch_buff];
if (leading_thread) {
uint64_t *dst = reinterpret_cast<uint64_t *>(&sIn[next_prefetch_buff]);
const uint64_t *src = reinterpret_cast<const uint64_t *>(&tensor_map_input);
// Arrive on the barrier and tell how many bytes are expected to come in
ptx::mbarrier_arrive_expect_tx(barrier, shmem_buff_size);
// Initiate bulk tensor copy
ptx::cp_async_bulk_tensor_2d_global_to_shared(dst, src, global_offset_X, global_offset_Y,
barrier);
}
ptx::fence_proxy_async_shared_cta();
}
// Wait for the data to have arrived
ptx::mbarrier_wait_parity_acquire_cta_shared_cta(&IN_buff_readable_mbar[buff_in],
IN_buff_readable_parity[buff_in]);
IN_buff_readable_parity[buff_in] ^= 1;
// Wait for TMA transfer to have finished reading shared memory
// I.e. the OUT buffer is ready to be written to
ptx::cp_async_bulk_wait_group_read<TunableConfig::PREFETCH_STAGES>();
// NVFP4 Quantization
rowwise_scaling<USE_STOCHASTIC_ROUNDING, USE_FAST_MATH>(
sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out,
rng, random_uint4, rnd_idx);
if constexpr (RETURN_TRANSPOSE) {
colwise_scaling<USE_STOCHASTIC_ROUNDING, USE_FAST_MATH>(
sIn_ptr, sOut_tr_ptr, sSFcolwise_ptr, S_enc_colwise, stage_Y, stage_X, buff_in,
buff_out_tr, rng, random_uint4, rnd_idx);
}
// Wait for shared memory writes to be visible to TMA engine
ptx::fence_proxy_async_shared_cta();
__syncthreads();
// After syncthreads, writes by all threads are visible to TMA engine
// Initiate TMA transfer to copy shared memory to global memory
if (leading_thread) {
const int global_offset_Y = block_offset_Y + stage_offset_Y;
const int global_offset_X = block_offset_X + stage_offset_X;
const int global_offset_Y_tr = block_offset_Y_tr + stage_offset_X;
const int global_offset_X_tr = block_offset_X_tr + stage_offset_Y;
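// For the transposed output the stage offsets swap roles: input tile (stage_Y, stage_X) lands
// at tile (stage_X, stage_Y) of the transposed tensor, hence the _tr offsets mix X and Y above.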
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output), global_offset_X,
global_offset_Y, reinterpret_cast<uint64_t *>(&sOut[buff_out]));
if constexpr (RETURN_TRANSPOSE) {
ptx::cp_async_bulk_tensor_2d_shared_to_global(
reinterpret_cast<const uint64_t *>(&tensor_map_output_t), global_offset_X_tr,
global_offset_Y_tr, reinterpret_cast<uint64_t *>(&sOut_tr[buff_out_tr]));
}
// Create a "bulk async-group" out of the previous bulk copy operation
ptx::cp_async_bulk_commit_group();
}
buff_in = (buff_in + 1) % BUFFS_NUM_IN;
buff_out = (buff_out + 1) % BUFFS_NUM_OUT;
buff_out_tr = (buff_out_tr + 1) % BUFFS_NUM_OUT_TR;
} // end of stages
// Vectorized store of scaling factors (S2G)
{
// Rowwise
{
using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_X>;
// number of scales in X dimension of this chunk
const int count = min(SCALES_PER_CHUNK_X, chunk_cols / SCALE_DIM);
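// `count` clamps the per-row store to the number of valid scale columns: for the last chunk in
// the X direction, chunk_cols may be smaller than CHUNK_DIM_X, so only chunk_cols / SCALE_DIM
// scaling factors are written.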
for (size_t row = threadIdx.x; row < TunableConfig::CHUNK_DIM_Y; row += THREADS_NUM) {
const size_t row_global = scales_block_offset_Y_rowwise + row;
if (row_global < rows) {
ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(sSFrowwise[row]);
const size_t scale_idx_global =
row_global * scale_stride + scales_block_offset_X_rowwise;
scales_vec.store_to_elts(&scales_ptr[scale_idx_global], 0, count);
}
}
}
// Colwise
if constexpr (RETURN_TRANSPOSE) {
using ScalesVec = Vec<nvfp4_scale_t, SCALES_PER_CHUNK_Y>;
// number of scales in Y dimension of this chunk
const int count = min(SCALES_PER_CHUNK_Y, chunk_rows / SCALE_DIM);
for (size_t row_tr = threadIdx.x; row_tr < TunableConfig::CHUNK_DIM_X;
row_tr += THREADS_NUM) {
const size_t row_tr_global = scales_block_offset_Y_tr + row_tr;
if (row_tr_global < cols) {
ScalesVec &scales_vec = *reinterpret_cast<ScalesVec *>(sSFcolwise[row_tr]);
const size_t scale_idx_global =
row_tr_global * scale_stride_t + scales_block_offset_X_tr;
scales_vec.store_to_elts(&scales_t_ptr[scale_idx_global], 0, count);
}
}
}
if (!job_finished) {
// Ensure all reads from the scaling-factor buffers have completed so they are ready to be reused
__syncthreads();
}
}
}
if (leading_thread) {
#pragma unroll
for (int buff = 0; buff < BUFFS_NUM; ++buff) {
ptx::mbarrier_invalid(&IN_buff_readable_mbar[buff]);
}
ptx::mbarrier_invalid(&workID_mbar);
}
#else
NVTE_DEVICE_ERROR("sm_100 or higher is required.");
#endif // (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
#endif // FP4_TYPE_SUPPORTED
} // namespace quantize_transpose_tuned_kernel
inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, Tensor *output,
const QuantizationConfig *quant_config,
cudaStream_t stream) {
#if FP4_TYPE_SUPPORTED
using namespace quantize_transpose_tuned_kernel;
using namespace ptx;
const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false;
const bool use_fast_math = quant_config ? quant_config->use_fast_math : false;
// If the transposed output is allocated, return the transposed data.
// Otherwise, it is not necessary to return the transposed data.
const bool return_transpose = output->has_columnwise_data();
checkCuDriverContext(stream);
CheckNoopTensor(*noop, "cast_noop");
CheckInputTensor(input, "input");
CheckOutputTensor(*output, "output", false);
NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data.");
NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated.");
NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type.");
NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated");
if (return_transpose) {
NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype),
"Transposed output must have FP4 type.");
NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr,
"Transposed scaling tensor must be allocated");
}
const size_t rows = input.flat_first_dim();
const size_t cols = input.flat_last_dim();
NVTE_CHECK(rows % 32 == 0,
"Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA
NVTE_CHECK(cols % 32 == 0,
"Number of tensor cols must be a multiple of 32"); // 16B alignment for TMA
const int blocks_Y = DIVUP(rows, static_cast<size_t>(TunableConfig::CHUNK_DIM_Y));
const int blocks_X = DIVUP(cols, static_cast<size_t>(TunableConfig::CHUNK_DIM_X));
const dim3 grid(blocks_X, blocks_Y);
const int block_size = THREADS_NUM;
const size_t scale_stride = output->scale_inv.shape[1];
const size_t scale_stride_transpose =
return_transpose ? output->columnwise_scale_inv.shape[1] : 0;
nvfp4_scale_t *const scales_ptr = reinterpret_cast<nvfp4_scale_t *>(output->scale_inv.dptr);
nvfp4_scale_t *const scales_transpose_ptr =
reinterpret_cast<nvfp4_scale_t *>(output->columnwise_scale_inv.dptr);
const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
const float *const amax_rowwise_ptr = reinterpret_cast<const float *>(output->amax.dptr);
const float *const amax_colwise_ptr =
reinterpret_cast<const float *>(output->columnwise_amax.dptr);
const NVTETensor rng_state_tensor = (quant_config != nullptr) ? quant_config->rng_state : nullptr;
const size_t *rng_state = nullptr;
if (rng_state_tensor != nullptr) {
Tensor &rng_state_te_tensor = *convertNVTETensor(rng_state_tensor);
NVTE_CHECK(rng_state_te_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_te_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_te_tensor.data.shape);
rng_state = reinterpret_cast<const size_t *>(rng_state_te_tensor.data.dptr);
}
alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_output{};
alignas(64) CUtensorMap tensor_map_output_transpose{};
create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
sizeof(IType) * 8);
create_2D_tensor_map(tensor_map_output, output->data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0,
4);
if (return_transpose) {
create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows,
BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4);
}
constexpr int buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr int buff_elems_total_in = BUFFS_NUM_IN * buff_elems;
constexpr int buff_size_aligned_in =
DIVUP_TO_MULTIPLE(buff_elems_total_in * sizeof(IType), TMA_SHMEM_ALIGNMENT);
constexpr int buff_size_aligned_out =
DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT * BUFF_OUT_SIZE, TMA_SHMEM_ALIGNMENT);
constexpr int buff_size_aligned_out_t =
DIVUP_TO_MULTIPLE(BUFFS_NUM_OUT_TR * BUFF_OUT_TR_SIZE, TMA_SHMEM_ALIGNMENT);
constexpr int buff_size_scales = DIVUP_TO_MULTIPLE(
TunableConfig::CHUNK_DIM_Y * SCALES_PER_CHUNK_X * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT);
constexpr int buff_size_scales_transpose = DIVUP_TO_MULTIPLE(
TunableConfig::CHUNK_DIM_X * SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t), TMA_SHMEM_ALIGNMENT);
const int in_mem = buff_size_aligned_in;
const int out_data_mem = buff_size_aligned_out;
const int out_data_transpose_mem = return_transpose ? buff_size_aligned_out_t : 0;
const int out_scales_mem = buff_size_scales;
const int out_scales_transpose_mem = return_transpose ? buff_size_scales_transpose : 0;
const int out_mem = out_data_mem + out_data_transpose_mem;
const int dshmem_size =
in_mem + out_mem + out_scales_transpose_mem + out_scales_mem + TMA_SHMEM_ALIGNMENT;
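// The extra TMA_SHMEM_ALIGNMENT term reserves room for align_smem_ptr_per_TMA_requirements() to
// round the dynamic shared-memory base pointer up to the TMA-required alignment inside the
// kernel, so every buffer computed above stays properly aligned.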
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, USE_STOCHASTIC_ROUNDING,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_fast_math, USE_FAST_MATH,
TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, {
auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel<USE_STOCHASTIC_ROUNDING,
USE_FAST_MATH, RETURN_TRANSPOSE>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr,
scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols,
scale_stride, scale_stride_transpose, rng_state);
});););
#else
NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION);
#endif // FP4_TYPE_SUPPORTED
}
} // namespace nvfp4
} // namespace dispatch
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_QUANTIZE_TRANSPOSE_NVFP4_TUNED_1D_CUH_
......@@ -8,7 +8,6 @@
#include <cublasmp.h>
#include <cuda_runtime.h>
#include <nvshmem.h>
#include <map>
#include <memory>
......@@ -236,7 +235,7 @@ void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n
ctx->grid_row_major.get(), ctx->d_desc.get()));
const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE;
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue));
}
......@@ -273,46 +272,46 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a,
sizeof trans_a));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b,
sizeof trans_b));
cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr,
sizeof algo_attr));
const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32;
if (is_fp8_dtype(a->dtype())) {
NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode,
sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER,
&a->scale_inv.dptr, sizeof(void*)));
}
if (is_fp8_dtype(b->dtype())) {
NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode,
sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER,
&b->scale_inv.dptr, sizeof(void*)));
}
if (is_fp8_dtype(d->dtype())) {
NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode,
sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER,
&d->scale.dptr, sizeof(void*)));
if (d->amax.dptr) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER,
&d->amax.dptr, sizeof(void*)));
}
......@@ -321,7 +320,7 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
// Might be set to ALLREDUCE before, need to OR with the new flags to set.
cublasMpMatmulEpilogue_t epilogue{};
size_t size_read{};
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorGetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue, &size_read));
NVTE_CHECK(size_read == sizeof epilogue);
......@@ -339,42 +338,42 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
pre_act_out ? pre_act_out->data.dptr != nullptr : false, grad});
it != flags_to_epilogue.end()) {
epilogue = static_cast<cublasMpMatmulEpilogue_t>(epilogue | it->second);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue));
}
if (bias && bias->data.dptr) {
cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type,
sizeof bias_type));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr,
sizeof bias->data.dptr));
}
if (pre_act_out && pre_act_out->data.dptr) {
cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE,
&aux_type, sizeof aux_type));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER,
&pre_act_out->data.dptr, sizeof pre_act_out->data.dptr));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd,
sizeof ldd));
if (is_fp8_dtype(pre_act_out->dtype())) {
NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE,
&scale_mode, sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER,
&pre_act_out->scale.dptr, sizeof(void*)));
if (pre_act_out->amax.dptr) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER,
&pre_act_out->amax.dptr, sizeof(void*)));
}
......@@ -382,12 +381,12 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
}
if (comm_sm_count) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT,
&comm_sm_count, sizeof comm_sm_count));
}
NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream));
NVTE_CHECK_CUBLASMP(cublasMpSetStream(ctx->cublas_mp.get(), main_stream));
size_t wrksp_size_device{};
size_t wrksp_size_host{};
......@@ -423,8 +422,14 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
std::vector<uint8_t> workspace_host(wrksp_size_host);
if (ctx->workspace_size < wrksp_size_device) {
nvshmem_free(ctx->workspace);
ctx->workspace = nvshmem_malloc(wrksp_size_device);
if (ctx->workspace) {
NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
}
NVTE_CHECK_CUBLASMP(
cublasMpMalloc(ctx->grid_col_major.get(), &ctx->workspace, wrksp_size_device));
NVTE_CHECK_CUBLASMP(
cublasMpBufferRegister(ctx->grid_row_major.get(), ctx->workspace, wrksp_size_device));
ctx->workspace_size = wrksp_size_device;
}
......@@ -473,7 +478,10 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank
void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) {
NVTE_API_CALL(nvte_comm_gemm_ctx_destroy);
nvshmemx_sync_all_on_stream(ctx->stream.get());
if (ctx->workspace) {
NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
}
delete ctx;
}
......
......@@ -321,6 +321,9 @@ struct GroupedTensor {
SimpleTensor columnwise_amax;
SimpleTensor scale; // for FP8-DS only
NVTEScalingMode scaling_mode;
size_t num_tensors;
// Shape information (OPTIONAL - empty if dimension is uniform across all tensors)
// first_dims[i] = first dimension of tensor i (empty if all tensors have same first dim)
// last_dims[i] = last dimension of tensor i (empty if all tensors have same last dim)
......@@ -338,10 +341,14 @@ struct GroupedTensor {
// Always 2D with positive dimensions
NVTEShape logical_shape;
NVTEScalingMode scaling_mode;
size_t num_tensors;
NVTEGroupedTensor nvte_tensor;
/*! \brief Whether scaling factors are in format expected by GEMM
*
* Only meaningful for MXFP8 and NVFP4.
*/
bool with_gemm_swizzled_scales = false;
GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
: data(),
columnwise_data(),
......@@ -350,12 +357,12 @@ struct GroupedTensor {
amax(),
columnwise_amax(),
scale(),
scaling_mode(scaling_mode),
num_tensors(num_tensors),
first_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
last_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
tensor_offsets(nullptr, std::vector<size_t>{0}, DType::kInt64),
logical_shape(nvte_make_shape(nullptr, 1)),
scaling_mode(scaling_mode),
nvte_tensor(0) {}
explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; }
......@@ -408,6 +415,7 @@ struct GroupedTensor {
num_tensors = 0;
scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
nvte_tensor = 0;
with_gemm_swizzled_scales = false;
}
};
......
......@@ -206,7 +206,7 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
float dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, int64_t window_size_left,
int64_t window_size_right, bool return_max_logit, bool cuda_graph) {
int64_t window_size_right, bool return_max_logit, bool cuda_graph, bool deterministic) {
using namespace transformer_engine;
NVTE_Fused_Attn_Backend backend = NVTE_Fused_Attn_Backend::NVTE_No_Backend;
const int device_id = cuda::current_device();
......@@ -406,9 +406,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
(window_size_right == -1 || window_size_right == 0)) ||
// 9.2: SWA (left, 0) + top-left diagonal + {bshd, sbhd}
(cudnn_runtime_version >= 90200 &&
((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
((window_size_left == -1 && window_size_right == -1 &&
attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK) ||
((window_size_left == -1 || window_size_left >= 0) && window_size_right == 0 &&
(attn_mask_type == NVTE_Mask_Type::NVTE_NO_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_MASK ||
(attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
max_seqlen_q == max_seqlen_kv)) &&
max_seqlen_q <= max_seqlen_kv && dropout == 0.0 &&
......@@ -418,12 +420,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// 9.6: SWA (left, 0) + top-left/bottom-right diagonal + {bshd, sbhd, thd}
(cudnn_runtime_version >= 90600 &&
((window_size_left == -1 && (window_size_right == -1 || window_size_right == 0)) ||
((window_size_left >= 0 || window_size_left == -1) && window_size_right == 0 &&
((window_size_left >= 0 || window_size_left == -1) &&
(window_size_right >= 0 || window_size_right == -1) &&
((attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK &&
// TODO(cyang): fix bug for BRCM + cross-attention on sm100
(sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv &&
cudnn_runtime_version <= 90700) ||
cudnn_runtime_version > 90700)))) ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK ||
attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK ||
(attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK &&
(sm_arch_ < 100 || (sm_arch_ >= 100 && ((max_seqlen_q == max_seqlen_kv &&
......@@ -440,7 +444,16 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// 9.13.1+: vanilla, off-by-one, learnable
(cudnn_runtime_version >= 91301 ||
(cudnn_runtime_version < 91301 &&
softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX))) {
softmax_type == NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX)) &&
// determinism on Blackwell
// pre-9.18.1: fwd: deterministic; bwd: non-deterministic
// 9.18.1+: fwd: deterministic; bwd: non-deterministic/deterministic
(sm_arch_ < 100 ||
(sm_arch_ >= 100 && (!is_training ||
(is_training && !deterministic &&
(dropout == 0.0 || bias_type == NVTE_Bias_Type::NVTE_NO_BIAS)) ||
(is_training && deterministic && cudnn_runtime_version >= 91801 &&
dropout == 0.0 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS))))) {
flag_arb = true;
}
if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) {
......@@ -506,16 +519,14 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend(
// NVTE fused attention FWD with packed QKV
// DEPRECATED: This API is deprecated.
// Please use nvte_fused_attn_fwd with separate Q, K, V tensors instead.
void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
const NVTETensor SoftmaxOffset, NVTETensor S, NVTETensor O,
NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state,
size_t max_seqlen, bool is_training, bool return_max_logit,
bool cuda_graph, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right,
NVTETensor workspace, cudaStream_t stream) {
void nvte_fused_attn_fwd_qkvpacked(
const NVTETensor QKV, const NVTETensor Bias, const NVTETensor SoftmaxOffset, NVTETensor S,
NVTETensor O, NVTETensorPack *Aux_CTX_Tensors, const NVTETensor cu_seqlens,
const NVTETensor cu_seqlens_padded, const NVTETensor rng_state, size_t max_seqlen,
bool is_training, bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
NVTE_Softmax_Type softmax_type, int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_qkvpacked);
using namespace transformer_engine;
......@@ -553,7 +564,7 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h, h, max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, return_max_logit,
cuda_graph);
cuda_graph, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -589,13 +600,14 @@ void nvte_fused_attn_fwd_qkvpacked(const NVTETensor QKV, const NVTETensor Bias,
fused_attn_arbitrary_seqlen_fwd(
b, h, h, max_seqlen, max_seqlen, d, d, t, t, 0, 0, 0, 0, 0, 0, is_training,
return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_left, window_size_right, &Q_view, &K_view, &V_view, input_Bias,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens, input_cu_seqlens,
input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr, input_rng_state,
wkspace, stream, handle);
window_size_left, window_size_right, bottom_right_diagonal, &Q_view, &K_view, &V_view,
input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens,
input_cu_seqlens, input_cu_seqlens_padded, input_cu_seqlens_padded, nullptr, nullptr,
input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. "
"\n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
......@@ -629,8 +641,8 @@ void nvte_fused_attn_bwd_qkvpacked(
NVTETensor dSoftmaxOffset, const NVTETensor cu_seqlens, const NVTETensor cu_seqlens_padded,
size_t max_seqlen, float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, bool cuda_graph,
NVTETensor workspace, cudaStream_t stream) {
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, bool cuda_graph, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_qkvpacked);
using namespace transformer_engine;
......@@ -669,7 +681,8 @@ void nvte_fused_attn_bwd_qkvpacked(
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, QKV_type, QKV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h, h,
max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph);
max_seqlen, max_seqlen, d, d, window_size_left, window_size_right, false, cuda_graph,
deterministic);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -725,10 +738,11 @@ void nvte_fused_attn_bwd_qkvpacked(
fused_attn_arbitrary_seqlen_bwd(
b, h, h, max_seqlen, max_seqlen, d, d, t, t, attn_scale, dropout, qkv_layout, bias_type,
attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic, &Q_view,
&K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S, &dQ_view,
&dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens,
input_cu_seqlens_padded, input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
attn_mask_type, softmax_type, window_size_left, window_size_right, bottom_right_diagonal,
deterministic, &Q_view, &K_view, &V_view, input_O, input_dO, input_Bias,
input_SoftmaxOffset, output_S, &dQ_view, &dK_view, &dV_view, output_dBias,
output_dSoftmaxOffset, input_cu_seqlens, input_cu_seqlens, input_cu_seqlens_padded,
input_cu_seqlens_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.0 is required for BF16/FP16 fused attention "
......@@ -779,7 +793,8 @@ void nvte_fused_attn_fwd_kvpacked(
size_t max_seqlen_kv, bool is_training, bool return_max_logit, bool cuda_graph,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, NVTETensor workspace, cudaStream_t stream) {
int64_t window_size_right, bool bottom_right_diagonal, NVTETensor workspace,
cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -855,7 +870,7 @@ void nvte_fused_attn_fwd_kvpacked(
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right,
return_max_logit, cuda_graph);
return_max_logit, cuda_graph, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -891,13 +906,14 @@ void nvte_fused_attn_fwd_kvpacked(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_left, window_size_right, input_Q, &K_view, &V_view, input_Bias,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
input_page_table_v, input_rng_state, wkspace, stream, handle);
window_size_left, window_size_right, bottom_right_diagonal, input_Q, &K_view, &V_view,
input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
"cuDNN 8.9.3 is required for BF16/FP16 fused attention with arbitrary sequence length. "
"\n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
......@@ -933,8 +949,8 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTETensor cu_seqlens_kv_padded, size_t max_seqlen_q, size_t max_seqlen_kv,
float attn_scale, float dropout, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type, int64_t window_size_left,
int64_t window_size_right, bool deterministic, bool cuda_graph, NVTETensor workspace,
cudaStream_t stream) {
int64_t window_size_right, bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd_kvpacked);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -982,10 +998,10 @@ void nvte_fused_attn_bwd_kvpacked(
const NVTEDType Q_type = static_cast<NVTEDType>(input_Q->data.dtype);
const NVTEDType KV_type = static_cast<NVTEDType>(input_KV->data.dtype);
NVTE_Fused_Attn_Backend fused_attention_backend =
nvte_get_fused_attn_backend(true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type,
softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d,
d, window_size_left, window_size_right, false, cuda_graph);
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
h_kv, max_seqlen_q, max_seqlen_kv, d, d, window_size_left, window_size_right, false,
cuda_graph, deterministic);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -1040,11 +1056,11 @@ void nvte_fused_attn_bwd_kvpacked(
fused_attn_arbitrary_seqlen_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d, d, t_q, t_kv, attn_scale, dropout, qkv_layout,
bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right, deterministic,
input_Q, &K_view, &V_view, input_O, input_dO, input_Bias, input_SoftmaxOffset, output_S,
output_dQ, &dK_view, &dV_view, output_dBias, output_dSoftmaxOffset, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_rng_state,
wkspace, stream, handle);
bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
bottom_right_diagonal, deterministic, input_Q, &K_view, &V_view, input_O, input_dO,
input_Bias, input_SoftmaxOffset, output_S, output_dQ, &dK_view, &dV_view, output_dBias,
output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
const char *err_msg =
"cuDNN 8.9.3 is required for BF16/FP16 fused attention "
......@@ -1094,8 +1110,8 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
bool return_max_logit, bool cuda_graph, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, NVTETensor workspace,
cudaStream_t stream) {
int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_fwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -1166,7 +1182,7 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right,
return_max_logit, cuda_graph);
return_max_logit, cuda_graph, false);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -1183,13 +1199,14 @@ void nvte_fused_attn_fwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, num_pages_k, num_pages_v,
page_size_k, page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, is_training,
return_max_logit, attn_scale, dropout, qkv_layout, bias_type, attn_mask_type, softmax_type,
window_size_left, window_size_right, input_Q, input_K, input_V, input_Bias,
input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q, input_cu_seqlens_kv,
input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded, input_page_table_k,
input_page_table_v, input_rng_state, wkspace, stream, handle);
window_size_left, window_size_right, bottom_right_diagonal, input_Q, input_K, input_V,
input_Bias, input_SoftmaxOffset, output_O, Aux_CTX_Tensors, input_cu_seqlens_q,
input_cu_seqlens_kv, input_cu_seqlens_q_padded, input_cu_seqlens_kv_padded,
input_page_table_k, input_page_table_v, input_rng_state, wkspace, stream, handle);
#else
NVTE_ERROR(
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. \n");
"cuDNN 8.9.0 is required for BF16/FP16 fused attention with arbitrary sequence length. "
"\n");
#endif
} else if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_FP8) {
#if (CUDNN_VERSION >= 8900)
......@@ -1215,8 +1232,9 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
size_t max_seqlen_kv, float attn_scale, float dropout,
NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type,
NVTE_Mask_Type attn_mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic,
bool cuda_graph, NVTETensor workspace, cudaStream_t stream) {
int64_t window_size_left, int64_t window_size_right,
bool bottom_right_diagonal, bool deterministic, bool cuda_graph,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_flash_attn_bwd);
using namespace transformer_engine;
const Tensor *input_cu_seqlens_q = convertNVTETensorCheck(cu_seqlens_q);
......@@ -1262,7 +1280,7 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
true, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, h_q,
h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, false,
cuda_graph);
cuda_graph, deterministic);
if (fused_attention_backend == NVTE_Fused_Attn_Backend::NVTE_F16_max512_seqlen) {
#if (CUDNN_VERSION >= 8901)
......@@ -1289,8 +1307,8 @@ void nvte_fused_attn_bwd(const NVTETensor Q, const NVTETensor K, const NVTETenso
fused_attn_arbitrary_seqlen_bwd(
b, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, t_q, t_kv, attn_scale, dropout,
qkv_layout, bias_type, attn_mask_type, softmax_type, window_size_left, window_size_right,
deterministic, input_Q, input_K, input_V, input_O, input_dO, input_Bias,
input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
bottom_right_diagonal, deterministic, input_Q, input_K, input_V, input_O, input_dO,
input_Bias, input_SoftmaxOffset, output_S, output_dQ, output_dK, output_dV, output_dBias,
output_dSoftmaxOffset, input_cu_seqlens_q, input_cu_seqlens_kv, input_cu_seqlens_q_padded,
input_cu_seqlens_kv_padded, input_rng_state, wkspace, stream, handle);
#else
......
......@@ -55,10 +55,10 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
int64_t max_pages_per_seq_v, int64_t bias_b, int64_t bias_h, bool is_training,
bool return_max_logit, float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, void *devPtrQ, void *devPtrK,
void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1, void *devPtrS2,
void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset, void *devPtrCuSeqlensQ,
void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal, void *devPtrQ,
void *devPtrK, void *devPtrV, void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrS1,
void *devPtrS2, void *devPtrO, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrPageTableK, void *devPtrPageTableV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
......@@ -75,6 +75,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
if (is_bottom_right && s_q == s_kv && !is_padding) {
is_causal = true;
is_bottom_right = false;
bottom_right_diagonal = false;
}
bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
bool is_dropout = (is_training && dropout_probability != 0.0f);
......@@ -129,6 +130,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
softmax_type,
window_size_left,
window_size_right,
bottom_right_diagonal,
true,
tensorType,
cudnn_frontend::DataType_t::NOT_SET,
......@@ -248,15 +250,21 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
fe::graph::SDPA_attributes sdpa_options;
sdpa_options = fe::graph::SDPA_attributes()
.set_name("flash_attention")
.set_is_inference(false)
.set_generate_stats(generate_stats)
.set_causal_mask(is_causal)
.set_causal_mask_bottom_right(is_bottom_right)
.set_attn_scale(attn_scale);
fe::DiagonalAlignment_t const &diagonal_alignment =
bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT
: fe::DiagonalAlignment_t::TOP_LEFT;
sdpa_options.set_diagonal_alignment(diagonal_alignment);
if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_options.set_diagonal_band_left_bound(window_size_left + 1);
}
if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
sdpa_options.set_diagonal_band_right_bound(window_size_right);
}
sdpa_options.set_alibi_mask(is_alibi);
......@@ -542,13 +550,14 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
int64_t max_b, int64_t max_t_q, int64_t max_t_kv, int64_t bias_b, int64_t bias_h,
float scaling_factor, float dropout_probability, NVTE_QKV_Layout layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, void *devPtrQ,
void *devPtrKTranspose, void *devPtrVTranspose, void *devPtrO, void *devPtrSoftmaxStats,
void *devPtrBias, void *devPtrSoftmaxOffset, void *devPtrdQ, void *devPtrdK, void *devPtrdV,
void *devPtrdO, void *devPtrdBias, void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed,
void *devPtrDropoutOffset, void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV,
void *devPtrSeqOffsetsQ, void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType,
void *workspace, size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, void *devPtrQ, void *devPtrKTranspose, void *devPtrVTranspose,
void *devPtrO, void *devPtrSoftmaxStats, void *devPtrBias, void *devPtrSoftmaxOffset,
void *devPtrdQ, void *devPtrdK, void *devPtrdV, void *devPtrdO, void *devPtrdBias,
void *devPtrdSoftmaxOffset, void *devPtrDropoutSeed, void *devPtrDropoutOffset,
void *devPtrCuSeqlensQ, void *devPtrCuSeqlensKV, void *devPtrSeqOffsetsQ,
void *devPtrSeqOffsetsKV, cudnn_frontend::DataType_t tensorType, void *workspace,
size_t *workspace_size, cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
bool is_bias = (bias_type == NVTE_Bias_Type::NVTE_POST_SCALE_BIAS);
......@@ -563,6 +572,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
if (is_bottom_right && s_q == s_kv && !is_padding) {
is_causal = true;
is_bottom_right = false;
bottom_right_diagonal = false;
}
bool is_softmax_offset = (softmax_type != NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX);
bool is_dropout = (dropout_probability != 0.0f);
......@@ -621,6 +631,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
softmax_type,
window_size_left,
window_size_right,
bottom_right_diagonal,
deterministic,
tensorType,
cudnn_frontend::DataType_t::NOT_SET,
......@@ -781,9 +792,17 @@ void fused_attn_arbitrary_seqlen_bwd_impl(
sdpa_backward_options.set_max_total_seq_len_kv(s_kv);
}
fe::DiagonalAlignment_t const &diagonal_alignment =
bottom_right_diagonal ? fe::DiagonalAlignment_t::BOTTOM_RIGHT
: fe::DiagonalAlignment_t::TOP_LEFT;
sdpa_backward_options.set_diagonal_alignment(diagonal_alignment);
if (cudnn_runtime_version >= 90200 && window_size_left != -1) {
sdpa_backward_options.set_diagonal_band_left_bound(window_size_left + 1);
}
if (cudnn_runtime_version >= 90600 && window_size_right != -1) {
sdpa_backward_options.set_diagonal_band_right_bound(window_size_right);
}
if (cudnn_runtime_version >= 90000) {
sdpa_backward_options.set_deterministic_algorithm(deterministic);
......@@ -1044,8 +1063,8 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
......@@ -1180,11 +1199,11 @@ void fused_attn_arbitrary_seqlen_fwd(
max_batch_size, max_tokens_q, max_tokens_kv, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h, is_training,
return_max_logit, attn_scale, p_dropout, qkv_layout, bias_type, mask_type, softmax_type,
window_size_left, window_size_right, devPtrQ, devPtrK, devPtrV, devPtrBias,
devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed, devPtrDropoutOffset,
devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV, devPtrSeqOffsetsQ,
devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr, &workspace_size,
stream, handle);
window_size_left, window_size_right, bottom_right_diagonal, devPtrQ, devPtrK, devPtrV,
devPtrBias, devPtrSoftmaxOffset, devPtrS1, devPtrS2, devPtrO, devPtrDropoutSeed,
devPtrDropoutOffset, devPtrCuSeqlensQ, devPtrCuSeqlensKV, devPtrPageTableK, devPtrPageTableV,
devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type), workspace->data.dptr,
&workspace_size, stream, handle);
if (workspace_size > 0) {
if (workspace->data.dptr == nullptr) {
......@@ -1206,13 +1225,14 @@ void fused_attn_arbitrary_seqlen_bwd(
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle) {
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle) {
using namespace transformer_engine;
const auto QKV_type = input_Q->data.dtype;
void *devPtrQ = input_Q->data.dptr;
......@@ -1273,8 +1293,8 @@ void fused_attn_arbitrary_seqlen_bwd(
batch, num_attn_heads, num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim_qk, head_dim_v,
max_batch_size, max_tokens_q, max_tokens_kv, bias_b, bias_h, attn_scale, p_dropout,
qkv_layout, bias_type, mask_type, softmax_type, window_size_left, window_size_right,
deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats, devPtrBias,
devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
bottom_right_diagonal, deterministic, devPtrQ, devPtrK, devPtrV, devPtrO, devPtrSoftmaxStats,
devPtrBias, devPtrSoftmaxOffset, devPtrdQ, devPtrdK, devPtrdV, devPtrdO, devPtrdBias,
devPtrdSoftmaxOffset, devPtrDropoutSeed, devPtrDropoutOffset, devPtrCuSeqlensQ,
devPtrCuSeqlensKV, devPtrSeqOffsetsQ, devPtrSeqOffsetsKV, get_cudnn_fe_dtype(QKV_type),
workspace->data.dptr, &workspace_size, stream, handle);
......
......@@ -25,8 +25,8 @@ void fused_attn_arbitrary_seqlen_fwd(
size_t page_size_v, size_t max_pages_per_seq_k, size_t max_pages_per_seq_v, bool is_training,
bool return_max_logit, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_O, NVTETensorPack *Aux_CTX_Tensors,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *page_table_k, const Tensor *page_table_v,
......@@ -37,13 +37,14 @@ void fused_attn_arbitrary_seqlen_bwd(
size_t max_seqlen_kv, size_t head_dim_qk, size_t head_dim_v, size_t num_tokens_q,
size_t num_tokens_kv, float attn_scale, float p_dropout, NVTE_QKV_Layout qkv_layout,
NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_Softmax_Type softmax_type,
int64_t window_size_left, int64_t window_size_right, bool deterministic, const Tensor *input_Q,
const Tensor *input_K, const Tensor *input_V, const Tensor *input_O, const Tensor *input_dO,
const Tensor *input_Bias, const Tensor *input_SoftmaxOffset, Tensor *output_S,
Tensor *output_dQ, Tensor *output_dK, Tensor *output_dV, Tensor *output_dBias,
Tensor *output_dSoftmaxOffset, const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv,
const Tensor *cu_seqlens_q_padded, const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state,
Tensor *workspace, cudaStream_t stream, cudnnHandle_t handle);
int64_t window_size_left, int64_t window_size_right, bool bottom_right_diagonal,
bool deterministic, const Tensor *input_Q, const Tensor *input_K, const Tensor *input_V,
const Tensor *input_O, const Tensor *input_dO, const Tensor *input_Bias,
const Tensor *input_SoftmaxOffset, Tensor *output_S, Tensor *output_dQ, Tensor *output_dK,
Tensor *output_dV, Tensor *output_dBias, Tensor *output_dSoftmaxOffset,
const Tensor *cu_seqlens_q, const Tensor *cu_seqlens_kv, const Tensor *cu_seqlens_q_padded,
const Tensor *cu_seqlens_kv_padded, const Tensor *rng_state, Tensor *workspace,
cudaStream_t stream, cudnnHandle_t handle);
#endif // CUDNN_VERSION >= 8900
} // namespace transformer_engine
......
......@@ -1707,6 +1707,7 @@ void fused_attn_fp8_fwd_impl_v1(
0,
0,
true,
true,
qkv_tensor_type,
o_tensor_type,
cudnn_frontend::DataType_t::NOT_SET,
......@@ -1809,7 +1810,7 @@ void fused_attn_fp8_fwd_impl_v1(
fe::graph::SDPA_fp8_attributes sdpa_options;
sdpa_options = fe::graph::SDPA_fp8_attributes()
.set_name("sdpa_fp8")
.set_is_inference(false)
.set_generate_stats(true)
.set_causal_mask(is_causal)
.set_attn_scale(attn_scale);
......@@ -2035,6 +2036,7 @@ void fused_attn_fp8_bwd_impl_v1(
NVTE_Softmax_Type::NVTE_VANILLA_SOFTMAX,
0,
0,
true,
false,
qkv_tensor_type,
o_tensor_type,
......
......@@ -535,11 +535,13 @@ size_t get_max_batch_size(size_t batch_size) {
// batch size is expected to be 10s-100s
// b = 1, ..., 32 -> max_b = 32
// b = 33, ..., 512 -> max_b = next power of 2
// otherwise -> max_b = b
// b = 513, ... -> max_b = round up to the next multiple of 512
if (log2_b <= 5) {
max_b = 32;
} else if (log2_b <= 9) {
max_b = pow(2, log2_b);
} else {
max_b = (batch_size + 511) / 512 * 512;
}
return max_b;
}
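// Illustrative sketch (not part of this change): the bucketing above, written standalone so
// the boundaries are easy to check by hand. It assumes log2_b = ceil(log2(batch_size)), which
// matches the comment table; the real log2_b computation lives outside this hunk.
static size_t max_batch_size_sketch(size_t batch_size) {
  size_t log2_b = 0;
  while ((size_t{1} << log2_b) < batch_size) ++log2_b;  // ceil(log2(batch_size))
  if (log2_b <= 5) return 32;                           // b in [1, 32]   -> 32
  if (log2_b <= 9) return size_t{1} << log2_b;          // b in [33, 512] -> next power of 2
  return (batch_size + 511) / 512 * 512;                // b >= 513       -> next multiple of 512
}
// Examples: 24 -> 32, 100 -> 128, 512 -> 512, 600 -> 1024, 1025 -> 1536.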
......
......@@ -111,6 +111,7 @@ struct FADescriptor_v1 {
NVTE_Softmax_Type softmax_type;
std::int64_t window_size_left;
std::int64_t window_size_right;
bool bottom_right_diagonal;
bool deterministic;
cudnn_frontend::DataType_t qkv_tensor_type;
cudnn_frontend::DataType_t o_tensor_type;
......@@ -122,15 +123,16 @@ struct FADescriptor_v1 {
return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
window_size_left, window_size_right, bottom_right_diagonal, deterministic,
bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type,
generate_max_sum_exp) <
std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal,
rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type,
rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
}
};
......
......@@ -126,3 +126,106 @@ void nvte_destroy_matmul_config(NVTEMatmulConfig config) {
delete reinterpret_cast<transformer_engine::MatmulConfig *>(config);
}
}
NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config() {
return new transformer_engine::GroupedMatmulConfig;
}
void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
NVTEGroupedMatmulConfigAttribute attr, void *buf,
size_t size_in_bytes, size_t *size_written) {
// Write attribute size
NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes,
"Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr];
*size_written = attr_size;
// Return immediately if buffer is not provided
if (buf == nullptr) {
return;
}
// Check buffer size
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for grouped matmul config attribute "
"(attribute ",
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
// Write to buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)");
const auto &config_ = *reinterpret_cast<const transformer_engine::GroupedMatmulConfig *>(config);
switch (attr) {
case kNVTEGroupedMatmulConfigAvgM: {
int64_t val = config_.avg_m.value_or(0);
std::memcpy(buf, &val, attr_size);
break;
}
case kNVTEGroupedMatmulConfigAvgN: {
int64_t val = config_.avg_n.value_or(0);
std::memcpy(buf, &val, attr_size);
break;
}
case kNVTEGroupedMatmulConfigAvgK: {
int64_t val = config_.avg_k.value_or(0);
std::memcpy(buf, &val, attr_size);
break;
}
case kNVTEGroupedMatmulConfigSMCount:
std::memcpy(buf, &config_.sm_count, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
}
}
void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
NVTEGroupedMatmulConfigAttribute attr,
const void *buf, size_t size_in_bytes) {
// Check attribute and buffer
NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes,
"Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr];
NVTE_CHECK(size_in_bytes >= attr_size,
"Buffer is too small for grouped matmul config attribute "
"(attribute ",
static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
" bytes)");
NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");
// Read from buffer
NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)");
auto &config_ = *reinterpret_cast<transformer_engine::GroupedMatmulConfig *>(config);
switch (attr) {
case kNVTEGroupedMatmulConfigAvgM: {
int64_t val;
std::memcpy(&val, buf, attr_size);
config_.avg_m = val;
break;
}
case kNVTEGroupedMatmulConfigAvgN: {
int64_t val;
std::memcpy(&val, buf, attr_size);
config_.avg_n = val;
break;
}
case kNVTEGroupedMatmulConfigAvgK: {
int64_t val;
std::memcpy(&val, buf, attr_size);
config_.avg_k = val;
break;
}
case kNVTEGroupedMatmulConfigSMCount:
std::memcpy(&config_.sm_count, buf, attr_size);
break;
default:
NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
}
}
void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config) {
if (config != nullptr) {
delete reinterpret_cast<transformer_engine::GroupedMatmulConfig *>(config);
}
}
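// Usage sketch (illustrative only): setting and reading back an attribute through the grouped
// matmul config API defined above. The attribute is passed as a plain int64_t value, matching
// the "value type, not std::optional" note in the config header.
static void grouped_matmul_config_usage_example() {
  NVTEGroupedMatmulConfig cfg = nvte_create_grouped_matmul_config();
  int64_t avg_m = 2048;
  nvte_set_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgM, &avg_m,
                                           sizeof(avg_m));
  int64_t read_back = 0;
  size_t size_written = 0;
  nvte_get_grouped_matmul_config_attribute(cfg, kNVTEGroupedMatmulConfigAvgM, &read_back,
                                           sizeof(read_back), &size_written);
  // read_back == 2048, size_written == sizeof(int64_t)
  nvte_destroy_grouped_matmul_config(cfg);
}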
......@@ -9,6 +9,9 @@
#include <transformer_engine/transformer_engine.h>
#include <cstdint>
#include <optional>
namespace transformer_engine {
struct MatmulConfig {
......@@ -31,6 +34,22 @@ struct MatmulConfig {
};
};
struct GroupedMatmulConfig {
// Average dimension hints for cuBLASLt algorithm selection heuristics.
// nullopt means "not set" - compute automatically from tensor shapes.
std::optional<int64_t> avg_m;
std::optional<int64_t> avg_n;
std::optional<int64_t> avg_k;
// Number of streaming multiprocessors to use in GEMM kernel
int sm_count = 0;
// Note: the attribute API transfers the underlying value type, not the std::optional wrapper
static constexpr size_t attr_sizes[] = {sizeof(decltype(avg_m)::value_type),
sizeof(decltype(avg_n)::value_type),
sizeof(decltype(avg_k)::value_type), sizeof(sm_count)};
};
} // namespace transformer_engine
#endif // TRANSFORMER_ENGINE_GEMM_CONFIG_H_
......@@ -311,13 +311,6 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
return ret;
}
/* cuBLAS version number at run-time */
size_t cublas_version() {
// Cache version to avoid cuBLAS logging overhead
static size_t version = cublasLtGetVersion();
return version;
}
} // namespace
#endif // __HIP_PLATFORM_AMD__
......@@ -518,8 +511,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif // CUBLAS_VERSION >= 120800
} else if (mxfp8_gemm) {
#if CUBLAS_VERSION >= 120800
NVTE_CHECK(cublas_version() >= 120800,
"MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120800,
"MXFP8 requires cuBLAS 12.8+, but run-time cuBLAS version is ",
transformer_engine::cuda::cublas_version());
// Check that scales are in expected format
NVTE_CHECK(inputA->with_gemm_swizzled_scales,
......@@ -541,7 +535,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
// Workaround for heuristic cache bug in cublasLt. This separates the MXFP8 cache key from non-block scaling.
// CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE is unused for block scaling so it's safe to set.
if (cublas_version() <= 120803) {
if (transformer_engine::cuda::cublas_version() <= 120803) {
const int64_t dummy_a_vec_stride = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_ALPHA_VECTOR_BATCH_STRIDE, &dummy_a_vec_stride,
......@@ -553,8 +547,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
#endif // CUBLAS_VERSION >= 120800
} else if (use_fp4) { // NVFP4 GEMM
#if CUBLAS_VERSION >= 120800
NVTE_CHECK(cublas_version() >= 120800,
"FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ", cublas_version());
NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120800,
"FP4 requires cuBLAS 12.8+, but run-time cuBLAS version is ",
transformer_engine::cuda::cublas_version());
// Check that scales are in expected format
NVTE_CHECK(inputA->with_gemm_swizzled_scales,
......@@ -589,9 +584,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
(inputB->scaling_mode == NVTE_BLOCK_SCALING_1D ||
inputB->scaling_mode == NVTE_BLOCK_SCALING_2D)) {
#if CUBLAS_VERSION >= 120900
NVTE_CHECK(cublas_version() >= 120900,
NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120900,
"FP8 block scaling requires cuBLAS 12.9+, but run-time cuBLAS version is ",
cublas_version());
transformer_engine::cuda::cublas_version());
// Check that matrix formats are valid
NVTE_CHECK((!(inputA->scaling_mode == NVTE_BLOCK_SCALING_2D &&
......@@ -624,7 +619,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
}
#if CUBLAS_VERSION >= 120800
if (cublas_version() >= 120800) {
if (transformer_engine::cuda::cublas_version() >= 120800) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
CUBLASLT_MATMUL_DESC_A_SCALE_MODE,
&scaling_mode_a, sizeof(scaling_mode_a)));
......@@ -641,7 +636,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_AMAX_D_POINTER, &D_amax, sizeof(D_amax)));
#if CUBLAS_VERSION >= 120800
if (cublas_version() >= 120800) {
if (transformer_engine::cuda::cublas_version() >= 120800) {
// NOTE: In all current cases where FP8 output is supported, the input is
// scaled identically to the output.
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(operationDesc,
......@@ -725,12 +720,14 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ",
CUBLAS_VERSION);
#else
NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000,
NVTE_CHECK(transformer_engine::cuda::cudart_version() >= 12020 &&
transformer_engine::cuda::cudart_version() < 13000,
"Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ",
cuda::cudart_version());
NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000,
transformer_engine::cuda::cudart_version());
NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 120205 &&
transformer_engine::cuda::cublas_version() < 130000,
"Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version());
transformer_engine::cuda::cublas_version());
if (m_split == 0) m_split = 1;
if (n_split == 0) n_split = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
......@@ -1201,9 +1198,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor
"Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ",
transformer_engine::cuda::cudart_version());
NVTE_CHECK(
cublas_version() >= 120205 && cublas_version() < 130000,
transformer_engine::cuda::cublas_version() >= 120205 &&
transformer_engine::cuda::cublas_version() < 130000,
"Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ",
cublas_version());
transformer_engine::cuda::cublas_version());
#endif
#else
#define NVTE_CUBLAS_ATOMIC_GEMM_COMPILE 1
......
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cublasLt.h>
#include <cublas_v2.h>
#include <cuda.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>
#include <cstdint>
#include "../common.h"
#include "../util/cuda_runtime.h"
#include "../util/handle_manager.h"
#include "../util/logging.h"
#include "./config.h"
namespace {
inline void CreateCublasHandle(cublasLtHandle_t *handle) {
NVTE_CHECK_CUBLAS(cublasLtCreate(handle));
}
} // namespace
#if CUBLAS_VERSION >= 130200
namespace {
// Helper struct to pass per-tensor shape/offset info (pointer or uniform value)
struct TensorShapeInfo {
const int64_t *first_dims; // nullptr if uniform
const int64_t *last_dims; // nullptr if uniform
const int64_t *offsets; // nullptr if need to compute
int64_t uniform_first; // used if first_dims == nullptr
int64_t uniform_last; // used if last_dims == nullptr
// Create from GroupedTensor
static TensorShapeInfo from_tensor(const transformer_engine::GroupedTensor *t) {
const bool has_first = t->first_dims.has_data();
const bool has_last = t->last_dims.has_data();
// When per-tensor dims are not provided, we must be in the uniform-shape case.
NVTE_CHECK(has_first || t->all_same_first_dim(),
"GroupedTensor is missing first_dims for varying shapes");
NVTE_CHECK(has_last || t->all_same_last_dim(),
"GroupedTensor is missing last_dims for varying shapes");
const int64_t *first_ptr =
has_first ? static_cast<const int64_t *>(t->first_dims.dptr) : nullptr;
const int64_t *last_ptr = has_last ? static_cast<const int64_t *>(t->last_dims.dptr) : nullptr;
const int64_t uniform_first = has_first ? 0 : static_cast<int64_t>(t->get_common_first_dim());
const int64_t uniform_last = has_last ? 0 : static_cast<int64_t>(t->get_common_last_dim());
return {first_ptr, last_ptr,
t->tensor_offsets.has_data() ? static_cast<const int64_t *>(t->tensor_offsets.dptr)
: nullptr,
uniform_first, uniform_last};
}
// Create for C tensor (uses D's dimensions, only has offsets)
static TensorShapeInfo create_shape_info_for_C(const transformer_engine::GroupedTensor *C,
const transformer_engine::GroupedTensor *D) {
const bool has_first = D->first_dims.has_data();
const bool has_last = D->last_dims.has_data();
NVTE_CHECK(has_first || D->all_same_first_dim(),
"GroupedTensor D is missing first_dims for varying shapes");
NVTE_CHECK(has_last || D->all_same_last_dim(),
"GroupedTensor D is missing last_dims for varying shapes");
const int64_t *first_ptr =
has_first ? static_cast<const int64_t *>(D->first_dims.dptr) : nullptr;
const int64_t *last_ptr = has_last ? static_cast<const int64_t *>(D->last_dims.dptr) : nullptr;
const int64_t uniform_first = has_first ? 0 : static_cast<int64_t>(D->get_common_first_dim());
const int64_t uniform_last = has_last ? 0 : static_cast<int64_t>(D->get_common_last_dim());
return {first_ptr, last_ptr,
C->tensor_offsets.has_data() ? static_cast<const int64_t *>(C->tensor_offsets.dptr)
: nullptr,
uniform_first, uniform_last};
}
};
// Helper functions to compute average dimensions from logical_shape for heuristics
// These are hints for cuBLASLt algorithm selection and don't need to be exact
inline int64_t compute_avg_first_dim(const transformer_engine::GroupedTensor *t) {
// logical_shape[0] is either num_tensors*M (uniform) or sum_of_M (varying first)
// In both cases, dividing by num_tensors gives the average
return static_cast<int64_t>(t->logical_shape.data[0]) / static_cast<int64_t>(t->num_tensors);
}
inline int64_t compute_avg_last_dim(const transformer_engine::GroupedTensor *t) {
if (t->all_same_last_dim()) {
// logical_shape[1] is the common N
return static_cast<int64_t>(t->logical_shape.data[1]);
}
// When varying, logical_shape[1] is expected to hold the sum of the last dims, so dividing by num_tensors gives the average.
return static_cast<int64_t>(t->logical_shape.data[1]) / static_cast<int64_t>(t->num_tensors);
}
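// Worked example (illustrative): for three grouped tensors with first dims {64, 128, 64} and a
// common last dim of 1024, logical_shape is {256, 1024}; compute_avg_first_dim returns
// 256 / 3 = 85 and compute_avg_last_dim returns 1024. Rough values are sufficient, since they
// only steer the cuBLASLt heuristics.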
// Workspace layout for grouped GEMM
struct GroupedGemmSetupWorkspace {
void **A_ptrs;
void **B_ptrs;
void **C_ptrs;
void **D_ptrs;
float **alpha_ptrs;
float **beta_ptrs;
// Storage dimensions for cuBLAS matrix layouts
int *a_rows;
int *a_cols;
int *b_rows;
int *b_cols;
int *d_rows; // M (first dim) - also used for C
int *d_cols; // N (last dim) - also used for C
// Initialize from workspace buffer
// Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned)
static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) {
GroupedGemmSetupWorkspace ws;
size_t offset = 0;
const size_t ptr_size = num_tensors * sizeof(void *);
const size_t int_size = num_tensors * sizeof(int);
// Pointer arrays first (all 8-byte aligned)
ws.A_ptrs = reinterpret_cast<void **>(setup_ws_ptr + offset);
offset += ptr_size;
ws.B_ptrs = reinterpret_cast<void **>(setup_ws_ptr + offset);
offset += ptr_size;
ws.C_ptrs = reinterpret_cast<void **>(setup_ws_ptr + offset);
offset += ptr_size;
ws.D_ptrs = reinterpret_cast<void **>(setup_ws_ptr + offset);
offset += ptr_size;
ws.alpha_ptrs = reinterpret_cast<float **>(setup_ws_ptr + offset);
offset += ptr_size;
ws.beta_ptrs = reinterpret_cast<float **>(setup_ws_ptr + offset);
offset += ptr_size;
// Int arrays for storage dimensions (4-byte aligned)
ws.a_rows = reinterpret_cast<int *>(setup_ws_ptr + offset);
offset += int_size;
ws.a_cols = reinterpret_cast<int *>(setup_ws_ptr + offset);
offset += int_size;
ws.b_rows = reinterpret_cast<int *>(setup_ws_ptr + offset);
offset += int_size;
ws.b_cols = reinterpret_cast<int *>(setup_ws_ptr + offset);
offset += int_size;
ws.d_rows = reinterpret_cast<int *>(setup_ws_ptr + offset);
offset += int_size;
ws.d_cols = reinterpret_cast<int *>(setup_ws_ptr + offset);
return ws;
}
// Calculate required size for setup workspace
static size_t required_setup_size(size_t num_tensors, size_t alignment) {
const size_t ptr_size = num_tensors * sizeof(void *);
const size_t int_size = num_tensors * sizeof(int);
// Layout: 6 ptr arrays, then 6 int arrays
size_t size = 6 * ptr_size + 6 * int_size;
size = ((size + alignment - 1) / alignment) * alignment;
return size;
}
};
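// Worked example (illustrative): for num_tensors = 4 and kGroupedGemmAlignment = 256,
// required_setup_size computes 6 * (4 * sizeof(void*)) + 6 * (4 * sizeof(int))
// = 6 * 32 + 6 * 16 = 288 bytes, which rounds up to 512 bytes, i.e.
// GroupedGemmSetupWorkspace::required_setup_size(4, 256) == 512.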
// -----------------------------------------------------------------------------
// Helper routines to keep nvte_grouped_gemm readable
// -----------------------------------------------------------------------------
inline void validate_grouped_gemm_inputs(const transformer_engine::GroupedTensor *inputA,
const transformer_engine::GroupedTensor *inputB,
const transformer_engine::GroupedTensor *inputC,
const transformer_engine::GroupedTensor *outputD,
const transformer_engine::Tensor *alpha_tensor,
const transformer_engine::Tensor *beta_tensor) {
const size_t num_tensors = inputA->num_tensors;
NVTE_CHECK(num_tensors >= 1, "Grouped GEMM: number of tensors must be at least 1");
NVTE_CHECK(inputB->num_tensors == num_tensors,
"Grouped GEMM: A and B must have the same number of tensors");
// C can be NULL (will use D as C when beta=0)
if (inputC != nullptr) {
NVTE_CHECK(inputC->num_tensors == num_tensors,
"Grouped GEMM: A and C must have the same number of tensors");
}
NVTE_CHECK(outputD->num_tensors == num_tensors,
"Grouped GEMM: A and D must have the same number of tensors");
// Validate alpha/beta have per-matrix values
const size_t alpha_numel = alpha_tensor->data.numel();
const size_t beta_numel = beta_tensor->data.numel();
NVTE_CHECK(alpha_numel == num_tensors, "Grouped GEMM: alpha must have num_tensors (", num_tensors,
") elements, got ", alpha_numel);
NVTE_CHECK(beta_numel == num_tensors, "Grouped GEMM: beta must have num_tensors (", num_tensors,
") elements, got ", beta_numel);
auto is_fp8_or_16bit = [](transformer_engine::DType dtype) {
return dtype == transformer_engine::DType::kFloat8E4M3 ||
dtype == transformer_engine::DType::kFloat8E5M2 ||
dtype == transformer_engine::DType::kBFloat16 ||
dtype == transformer_engine::DType::kFloat16;
};
auto is_output_dtype = [](transformer_engine::DType dtype) {
return dtype == transformer_engine::DType::kBFloat16 ||
dtype == transformer_engine::DType::kFloat16 ||
dtype == transformer_engine::DType::kFloat32;
};
NVTE_CHECK(is_fp8_or_16bit(inputA->dtype()) && is_fp8_or_16bit(inputB->dtype()),
"Grouped GEMM inputs must be FP8, BF16, or FP16.");
// Only check C dtype if C is provided
if (inputC != nullptr) {
NVTE_CHECK(is_output_dtype(inputC->dtype()), "Grouped GEMM: C must be BF16, FP16, or FP32.");
}
NVTE_CHECK(is_output_dtype(outputD->dtype()), "Grouped GEMM: D must be BF16, FP16, or FP32.");
NVTE_CHECK(inputA->has_data() || inputA->has_columnwise_data(),
"Grouped GEMM: A tensor is missing both row-wise and column-wise data");
NVTE_CHECK(inputB->has_data() || inputB->has_columnwise_data(),
"Grouped GEMM: B tensor is missing both row-wise and column-wise data");
}
// Select row-wise vs column-wise storage and adjust transpose flag for grouped GEMM.
// Mirrors the non-grouped GEMM logic for FP8 layout handling (TN-only on Hopper) and
// fallback to column-wise data when row-wise is absent.
// Contains all information needed for GEMM setup - shape already accounts for storage layout.
struct GroupedOperandSelection {
TensorShapeInfo shape; // Shape info with dims already swapped for columnwise if needed
char *dptr = nullptr;
void *scale_inv = nullptr;
transformer_engine::DType dtype = transformer_engine::DType::kNumTypes;
bool trans = false;
};
// Helper to create TensorShapeInfo from a GroupedTensor, optionally swapping first/last dims.
// When swap_dims=true, first_dims and last_dims are swapped to account for columnwise storage.
// Note: tensor_offsets are the same for rowwise and columnwise data (same element count per tensor).
inline TensorShapeInfo create_shape_info(const transformer_engine::GroupedTensor *t,
bool swap_dims) {
const bool has_first = t->first_dims.has_data();
const bool has_last = t->last_dims.has_data();
NVTE_CHECK(has_first || t->all_same_first_dim(),
"GroupedTensor is missing first_dims for varying shapes");
NVTE_CHECK(has_last || t->all_same_last_dim(),
"GroupedTensor is missing last_dims for varying shapes");
const int64_t *first_ptr = has_first ? static_cast<const int64_t *>(t->first_dims.dptr) : nullptr;
const int64_t *last_ptr = has_last ? static_cast<const int64_t *>(t->last_dims.dptr) : nullptr;
const int64_t uniform_first = has_first ? 0 : static_cast<int64_t>(t->get_common_first_dim());
const int64_t uniform_last = has_last ? 0 : static_cast<int64_t>(t->get_common_last_dim());
const int64_t *offsets_ptr =
t->tensor_offsets.has_data() ? static_cast<const int64_t *>(t->tensor_offsets.dptr) : nullptr;
if (swap_dims) {
// Swap first/last to account for columnwise (transposed) storage
return {last_ptr, first_ptr, offsets_ptr, uniform_last, uniform_first};
}
return {first_ptr, last_ptr, offsets_ptr, uniform_first, uniform_last};
}
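// Example (illustrative): for a grouped tensor whose members are logically [M_i, N] but whose
// selected storage is column-wise (pre-transposed), create_shape_info with swap_dims=true swaps
// first/last so downstream code sees the [N, M_i] storage shape rather than the logical one;
// the per-tensor offsets are reused unchanged because the element count per tensor is the same.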
inline GroupedOperandSelection select_grouped_operand(const transformer_engine::GroupedTensor *t,
bool trans, bool is_A) {
using namespace transformer_engine;
const bool has_row = t->has_data();
const bool has_col = t->has_columnwise_data();
NVTE_CHECK(has_row || has_col,
"Grouped GEMM operand is missing both row-wise and column-wise data");
// Currently only unquantized data and tensor-scaled FP8 are supported.
const auto sm = t->scaling_mode;
NVTE_CHECK(sm == NVTE_DELAYED_TENSOR_SCALING,
"Grouped GEMM is only supported with unquantized data and tensor-scaled FP8 data");
const DType row_dtype = t->data.dtype;
const DType col_dtype = t->columnwise_data.dtype;
GroupedOperandSelection sel;
sel.trans = trans;
const DType rep_dtype = has_row ? row_dtype : col_dtype;
const bool is_fp8 = is_fp8_dtype(rep_dtype);
const bool non_tn_fp8_ok = nvte_is_non_tn_fp8_gemm_supported();
// Helper to select columnwise storage (swaps dims in shape)
auto use_columnwise = [&]() {
sel.dptr = static_cast<char *>(t->columnwise_data.dptr);
sel.scale_inv = t->columnwise_scale_inv.dptr;
sel.dtype = col_dtype;
sel.shape = create_shape_info(t, /*swap_dims=*/true);
};
// Helper to select row-wise storage
auto use_rowwise = [&]() {
sel.dptr = static_cast<char *>(t->data.dptr);
sel.scale_inv = t->scale_inv.dptr;
sel.dtype = row_dtype;
sel.shape = create_shape_info(t, /*swap_dims=*/false);
};
// Hopper-style TN-only FP8: force TN by switching layout and flipping transpose when needed.
if (is_fp8 && !non_tn_fp8_ok) {
if (is_A) {
if (!sel.trans) {
NVTE_CHECK(has_col, "Grouped GEMM: A is missing column-wise data needed for FP8 TN layout");
use_columnwise();
sel.trans = true; // using pre-transposed storage
return sel;
}
} else { // B
if (sel.trans) {
NVTE_CHECK(has_col, "Grouped GEMM: B is missing column-wise data needed for FP8 TN layout");
use_columnwise();
sel.trans = false; // using pre-transposed storage
return sel;
}
}
}
// If only column-wise data is available, flip the transpose flag to account for the pre-transposed storage.
if (!has_row && has_col) {
// On Hopper FP8, this would break TN requirement - should have been handled above
NVTE_CHECK(
!is_fp8 || non_tn_fp8_ok,
"Grouped GEMM: FP8 on Hopper requires row-wise data for this transpose configuration");
use_columnwise();
sel.trans = !trans; // flip transpose for pre-transposed storage
return sel;
}
// Default: use row-wise data
use_rowwise();
return sel;
}
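// Example (illustrative): on an architecture where nvte_is_non_tn_fp8_gemm_supported() returns
// false (TN-only FP8 GEMM), an FP8 A operand requested with trans=false is served from its
// column-wise (pre-transposed) buffer and reported with trans=true, while an FP8 B operand
// requested with trans=true is served column-wise and reported with trans=false. The GEMM
// therefore always runs in the TN layout while the caller keeps its original flags.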
inline void *validate_and_get_workspace_ptr(transformer_engine::Tensor *ws, size_t required_size,
const char *workspace_name) {
NVTE_CHECK(ws != nullptr, workspace_name, " tensor is null.");
const size_t provided_size = get_buffer_size_bytes(ws->data.numel(), ws->data.dtype);
NVTE_CHECK(provided_size >= required_size, "Grouped GEMM: Insufficient ", workspace_name,
". Required: ", required_size, " bytes, Available: ", provided_size, " bytes.");
return ws->data.dptr;
}
inline void init_matrix_layouts(cublasLtMatrixLayoutOpaque_t &descA,
cublasLtMatrixLayoutOpaque_t &descB,
cublasLtMatrixLayoutOpaque_t &descC,
cublasLtMatrixLayoutOpaque_t &descD,
const GroupedGemmSetupWorkspace &ws,
const GroupedOperandSelection &A_sel,
const GroupedOperandSelection &B_sel,
const transformer_engine::GroupedTensor *D, size_t num_tensors) {
const cudaDataType_t A_type = get_cuda_dtype(A_sel.dtype);
const cudaDataType_t B_type = get_cuda_dtype(B_sel.dtype);
const cudaDataType_t D_type = get_cuda_dtype(D->dtype());
// Storage dimensions computed by kernel, leading dimension = rows
NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descA, A_type, num_tensors, ws.a_rows,
ws.a_cols, ws.a_rows));
NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descB, B_type, num_tensors, ws.b_rows,
ws.b_cols, ws.b_rows));
NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descC, D_type, num_tensors, ws.d_rows,
ws.d_cols, ws.d_rows));
NVTE_CHECK_CUBLAS(cublasLtGroupedMatrixLayoutInit(&descD, D_type, num_tensors, ws.d_rows,
ws.d_cols, ws.d_rows));
}
inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOperation_t op_A,
cublasOperation_t op_B) {
NVTE_CHECK_CUBLAS(cublasLtMatmulDescInit(&matmulDesc, CUBLAS_COMPUTE_32F, CUDA_R_32F));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSA, &op_A,
sizeof(op_A)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_TRANSB, &op_B,
sizeof(op_B)));
cublasLtPointerMode_t pointer_mode = CUBLASLT_POINTER_MODE_DEVICE;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_POINTER_MODE,
&pointer_mode, sizeof(pointer_mode)));
int64_t alphabeta_batch_stride = 1;
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc,
CUBLASLT_MATMUL_DESC_ALPHA_BATCH_STRIDE,
&alphabeta_batch_stride, sizeof(int64_t)));
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc,
CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE,
&alphabeta_batch_stride, sizeof(int64_t)));
}
inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc,
const GroupedOperandSelection &A_sel,
const GroupedOperandSelection &B_sel) {
const bool is_fp8_a = is_fp8_dtype(A_sel.dtype);
const bool is_fp8_b = is_fp8_dtype(B_sel.dtype);
if (!is_fp8_a && !is_fp8_b) return;
if (is_fp8_a) {
void *a_scale_inv = A_sel.scale_inv;
NVTE_CHECK(a_scale_inv != nullptr, "FP8 grouped GEMM: A scale_inv is required");
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
&matmulDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale_inv, sizeof(a_scale_inv)));
}
if (is_fp8_b) {
void *b_scale_inv = B_sel.scale_inv;
NVTE_CHECK(b_scale_inv != nullptr, "FP8 grouped GEMM: B scale_inv is required");
NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(
&matmulDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale_inv, sizeof(b_scale_inv)));
}
}
// Constants for grouped GEMM workspace (declared early for use in heuristics)
static constexpr size_t kGroupedGemmAlignment = 256;
static constexpr size_t kGroupedGemmCublasWorkspaceSize = 32ull * 1024 * 1024; // 32 MiB
inline cublasLtMatmulAlgo_t select_grouped_gemm_algo(cublasLtHandle_t handle,
cublasLtMatmulDescOpaque_t &matmulDesc,
cublasLtMatrixLayoutOpaque_t &descA,
cublasLtMatrixLayoutOpaque_t &descB,
cublasLtMatrixLayoutOpaque_t &descC,
cublasLtMatrixLayoutOpaque_t &descD,
int64_t avg_m, int64_t avg_n, int64_t avg_k) {
cublasLtMatmulPreferenceOpaque_t preference;
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceInit(&preference));
NVTE_CHECK_CUBLAS(
cublasLtMatmulPreferenceSetAttribute(&preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
&kGroupedGemmCublasWorkspaceSize, sizeof(size_t)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_ROWS, &avg_m, sizeof(int64_t)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_GROUPED_DESC_D_AVERAGE_COLS, &avg_n, sizeof(int64_t)));
NVTE_CHECK_CUBLAS(cublasLtMatmulPreferenceSetAttribute(
&preference, CUBLASLT_MATMUL_PREF_GROUPED_AVERAGE_REDUCTION_DIM, &avg_k, sizeof(int64_t)));
cublasLtMatmulHeuristicResult_t heuristicResult;
int returnedResults = 0;
auto status = cublasLtMatmulAlgoGetHeuristic(handle, &matmulDesc, &descA, &descB, &descC, &descD,
&preference, 1, &heuristicResult, &returnedResults);
NVTE_CHECK(status != CUBLAS_STATUS_NOT_SUPPORTED,
"Unable to find suitable cuBLAS grouped GEMM algorithm");
NVTE_CHECK_CUBLAS(status);
NVTE_CHECK(returnedResults > 0, "No suitable algorithm found for grouped GEMM");
return heuristicResult.algo;
}
// Single kernel that sets up all GEMM parameters.
// Rationale: cuBLASLt grouped matmul API needs flat arrays of pointers and per-matrix dimensions,
// but NVTEGroupedTensor stores a single contiguous buffer + optional per-tensor offsets/shapes.
// We bridge the mismatch on GPU by computing per-group pointers and storage dims in one kernel.
__global__ void setup_grouped_gemm_kernel(
// Output arrays
void **A_ptrs, void **B_ptrs, void **C_ptrs, void **D_ptrs, int *a_rows, int *a_cols,
int *b_rows, int *b_cols, int *d_rows, int *d_cols, float **alpha_ptrs, float **beta_ptrs,
// Inputs
char *a_base, char *b_base, char *c_base, char *d_base, TensorShapeInfo A_meta,
TensorShapeInfo B_meta, TensorShapeInfo C_meta, TensorShapeInfo D_meta, size_t a_elem_size,
size_t b_elem_size, size_t c_elem_size, size_t d_elem_size, float *alpha_ptr, float *beta_ptr,
size_t num_tensors) {
size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= num_tensors) return;
// Get dimensions for this tensor (from array or uniform value)
int64_t a_first = A_meta.first_dims ? A_meta.first_dims[idx] : A_meta.uniform_first;
int64_t a_last = A_meta.last_dims ? A_meta.last_dims[idx] : A_meta.uniform_last;
int64_t b_first = B_meta.first_dims ? B_meta.first_dims[idx] : B_meta.uniform_first;
int64_t b_last = B_meta.last_dims ? B_meta.last_dims[idx] : B_meta.uniform_last;
int64_t d_first = D_meta.first_dims ? D_meta.first_dims[idx] : D_meta.uniform_first;
int64_t d_last = D_meta.last_dims ? D_meta.last_dims[idx] : D_meta.uniform_last;
// Compute offsets (from array or compute from uniform dims)
int64_t a_offset =
A_meta.offsets ? A_meta.offsets[idx] : (idx * A_meta.uniform_first * A_meta.uniform_last);
int64_t b_offset =
B_meta.offsets ? B_meta.offsets[idx] : (idx * B_meta.uniform_first * B_meta.uniform_last);
int64_t c_offset =
C_meta.offsets ? C_meta.offsets[idx] : (idx * C_meta.uniform_first * C_meta.uniform_last);
int64_t d_offset =
D_meta.offsets ? D_meta.offsets[idx] : (idx * D_meta.uniform_first * D_meta.uniform_last);
// Compute data pointers
A_ptrs[idx] = a_base + a_offset * a_elem_size;
B_ptrs[idx] = b_base + b_offset * b_elem_size;
C_ptrs[idx] = c_base + c_offset * c_elem_size;
D_ptrs[idx] = d_base + d_offset * d_elem_size;
// Compute storage dimensions for cuBLAS matrix layouts.
// For INPUTS (A, B): Row-wise storage is seen as transposed column-major by cuBLAS,
// so rows=last, cols=first. For columnwise, dims are already swapped.
a_rows[idx] = static_cast<int>(a_last);
a_cols[idx] = static_cast<int>(a_first);
b_rows[idx] = static_cast<int>(b_last);
b_cols[idx] = static_cast<int>(b_first);
// For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N).
d_rows[idx] = static_cast<int>(d_first);
d_cols[idx] = static_cast<int>(d_last);
// Fill alpha/beta pointers (per-matrix)
alpha_ptrs[idx] = alpha_ptr + idx;
beta_ptrs[idx] = beta_ptr + idx;
}
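// Example (illustrative): with a uniform per-tensor shape of [128, 512] and no explicit offsets,
// group idx computes a_offset = idx * 128 * 512 elements, so A_ptrs[idx] points
// idx * 128 * 512 * a_elem_size bytes past a_base. When per-tensor offsets are provided, they
// are taken verbatim instead of being derived from the uniform dims.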
// Launch the setup kernel to populate workspace arrays
inline void launch_grouped_gemm_setup(
const GroupedGemmSetupWorkspace &ws, const GroupedOperandSelection &A_sel,
const GroupedOperandSelection &B_sel, const transformer_engine::GroupedTensor *C,
const transformer_engine::GroupedTensor *D, const transformer_engine::Tensor *alpha_tensor,
const transformer_engine::Tensor *beta_tensor, size_t num_tensors, cudaStream_t stream) {
// Use shape info from selection (already accounts for columnwise dimension swap)
TensorShapeInfo A_meta = A_sel.shape;
TensorShapeInfo B_meta = B_sel.shape;
TensorShapeInfo C_meta = TensorShapeInfo::create_shape_info_for_C(C, D);
TensorShapeInfo D_meta = TensorShapeInfo::from_tensor(D);
char *c_base = static_cast<char *>(C->data.dptr);
char *d_base = static_cast<char *>(D->data.dptr);
const size_t a_elem_size = transformer_engine::typeToSize(A_sel.dtype);
const size_t b_elem_size = transformer_engine::typeToSize(B_sel.dtype);
const size_t c_elem_size = transformer_engine::typeToSize(C->dtype());
const size_t d_elem_size = transformer_engine::typeToSize(D->dtype());
const int threads_per_block = 256;
const int num_blocks = (num_tensors + threads_per_block - 1) / threads_per_block;
setup_grouped_gemm_kernel<<<num_blocks, threads_per_block, 0, stream>>>(
ws.A_ptrs, ws.B_ptrs, ws.C_ptrs, ws.D_ptrs, ws.a_rows, ws.a_cols, ws.b_rows, ws.b_cols,
ws.d_rows, ws.d_cols, ws.alpha_ptrs, ws.beta_ptrs, A_sel.dptr, B_sel.dptr, c_base, d_base,
A_meta, B_meta, C_meta, D_meta, a_elem_size, b_elem_size, c_elem_size, d_elem_size,
static_cast<float *>(alpha_tensor->data.dptr), static_cast<float *>(beta_tensor->data.dptr),
num_tensors);
NVTE_CHECK_CUDA(cudaGetLastError());
}
inline size_t grouped_gemm_setup_workspace_size(size_t num_tensors) {
return GroupedGemmSetupWorkspace::required_setup_size(num_tensors, kGroupedGemmAlignment);
}
} // namespace
void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
const NVTETensor beta, NVTETensor workspace_setup,
NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config,
cudaStream_t stream) {
NVTE_API_CALL(nvte_grouped_gemm);
using namespace transformer_engine;
// Grouped GEMM requires Blackwell (SM100) or newer and cuBLAS 13.2+
const int current_device = transformer_engine::cuda::current_device();
NVTE_CHECK(transformer_engine::cuda::sm_arch(current_device) >= 100,
"nvte_grouped_gemm requires Blackwell (SM100) or newer architecture.");
NVTE_CHECK(transformer_engine::cuda::cublas_version() >= 130200,
"nvte_grouped_gemm requires cuBLAS 13.2+, but run-time cuBLAS version is ",
transformer_engine::cuda::cublas_version());
// Convert to internal types
const GroupedTensor *inputA = convertNVTEGroupedTensorCheck(A);
const GroupedTensor *inputB = convertNVTEGroupedTensorCheck(B);
const GroupedTensor *inputC_raw = convertNVTEGroupedTensor(C); // Can be NULL
GroupedTensor *outputD = convertNVTEGroupedTensorCheck(D);
const Tensor *alpha_tensor = convertNVTETensorCheck(alpha);
const Tensor *beta_tensor = convertNVTETensorCheck(beta);
Tensor *wspace_setup = convertNVTETensor(workspace_setup);
Tensor *wspace_cublas = convertNVTETensor(workspace_cublas);
// Parse config (if provided)
GroupedMatmulConfig config_;
if (config != nullptr) {
config_ = *reinterpret_cast<GroupedMatmulConfig *>(config);
}
// Validate inputs and num_tensors
validate_grouped_gemm_inputs(inputA, inputB, inputC_raw, outputD, alpha_tensor, beta_tensor);
// If C is NULL, use D as C (valid when beta=0, cuBLAS won't read C data)
const GroupedTensor *inputC = (inputC_raw != nullptr) ? inputC_raw : outputD;
const size_t num_tensors = inputA->num_tensors;
// Select operand storage (row-wise vs column-wise) and adjust transpose flags to
// mirror the non-grouped GEMM logic for FP8 layout constraints.
const auto A_sel = select_grouped_operand(inputA, static_cast<bool>(transa), /*is_A=*/true);
const auto B_sel = select_grouped_operand(inputB, static_cast<bool>(transb), /*is_A=*/false);
// Workspaces: setup (pointer arrays) and cuBLAS
const size_t setup_workspace_size = grouped_gemm_setup_workspace_size(num_tensors);
const size_t cublas_workspace_size = kGroupedGemmCublasWorkspaceSize;
void *setup_workspace_ptr = validate_and_get_workspace_ptr(wspace_setup, setup_workspace_size,
"Grouped GEMM setup workspace");
void *cublas_workspace_ptr = validate_and_get_workspace_ptr(wspace_cublas, cublas_workspace_size,
"Grouped GEMM cuBLAS workspace");
auto setup_workspace = GroupedGemmSetupWorkspace::from_buffers(
static_cast<char *>(setup_workspace_ptr), num_tensors);
launch_grouped_gemm_setup(setup_workspace, A_sel, B_sel, inputC, outputD, alpha_tensor,
beta_tensor, num_tensors, stream);
// Get cuBLAS handle
using cublasHandleManager = detail::HandleManager<cublasLtHandle_t, CreateCublasHandle>;
cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle();
// Setup cuBLAS operations
cublasOperation_t op_A = A_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N;
cublasOperation_t op_B = B_sel.trans ? CUBLAS_OP_T : CUBLAS_OP_N;
// Create grouped matrix layouts
cublasLtMatrixLayoutOpaque_t descA, descB, descC, descD;
init_matrix_layouts(descA, descB, descC, descD, setup_workspace, A_sel, B_sel, outputD,
num_tensors);
// Create matmul descriptor
cublasLtMatmulDescOpaque_t matmulDesc;
init_matmul_desc(matmulDesc, op_A, op_B);
set_fp8_scale_pointers(matmulDesc, A_sel, B_sel);
// Compute average dimensions for heuristics
// K dimension: if transa, K is A's first dim; if not, K is A's last dim
// Use original inputA and transa for heuristics (not modified A_sel.trans)
int64_t avg_m_val = config_.avg_m.value_or(compute_avg_first_dim(outputD));
int64_t avg_n_val = config_.avg_n.value_or(compute_avg_last_dim(outputD));
int64_t avg_k_val =
config_.avg_k.value_or(transa ? compute_avg_first_dim(inputA) : compute_avg_last_dim(inputA));
// Heuristic selection
cublasLtMatmulAlgo_t algo = select_grouped_gemm_algo(handle, matmulDesc, descA, descB, descC,
descD, avg_m_val, avg_n_val, avg_k_val);
// Execute the grouped GEMM
NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, &matmulDesc, setup_workspace.alpha_ptrs,
setup_workspace.A_ptrs, &descA, setup_workspace.B_ptrs, &descB,
setup_workspace.beta_ptrs, setup_workspace.C_ptrs, &descC,
setup_workspace.D_ptrs, &descD, &algo, cublas_workspace_ptr,
kGroupedGemmCublasWorkspaceSize, stream));
}
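// Caller-side sizing sketch (illustrative; the exact sizing contract is an assumption that
// simply mirrors GroupedGemmSetupWorkspace::required_setup_size above): workspace_setup must
// hold six pointer arrays plus six int arrays of length num_tensors, rounded up to 256 bytes,
// and workspace_cublas must hold kGroupedGemmCublasWorkspaceSize (32 MiB).
static size_t grouped_gemm_setup_bytes_sketch(size_t num_tensors) {
  const size_t bytes = 6 * num_tensors * sizeof(void *) + 6 * num_tensors * sizeof(int);
  return (bytes + 255) / 256 * 256;  // e.g. num_tensors = 64 -> 4608 bytes
}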
#else // CUBLAS_VERSION < 130200
void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb,
const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha,
const NVTETensor beta, NVTETensor workspace_setup,
NVTETensor workspace_cublas, NVTEGroupedMatmulConfig config,
cudaStream_t stream) {
NVTE_ERROR("nvte_grouped_gemm requires cuBLAS 13.2+, but compile-time cuBLAS version is ",
CUBLAS_VERSION, ". Please upgrade to CUDA 13.1 or newer.");
}
#endif // CUBLAS_VERSION >= 130200
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <transformer_engine/hadamard_transform.h>
#include <transformer_engine/multi_tensor.h>
#include <cuda/barrier>
#include "common/common.h"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "hadamard_transform_utils.cuh"
namespace transformer_engine {
namespace {
constexpr int kMaxTensorsPerKernel = 64;
constexpr int kThreadsPerWarp = 32;
enum ShapeRepresentation {
SAME_BOTH_DIMS = 0,
VARYING_FIRST_DIM = 1,
VARYING_LAST_DIM = 2,
VARYING_BOTH_DIMS = 3
};
__device__ __forceinline__ size_t get_current_tensor_id(
const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset,
const size_t first_logical_dim, const size_t last_logical_dim,
const int64_t* const __restrict__ offsets_ptr) {
if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) {
const size_t current_row = current_offset / last_logical_dim;
const size_t rows_per_tensor = first_logical_dim / num_tensors;
return current_row / rows_per_tensor;
} else {
// upper_bound(offsets, current_offset) - 1, searching over i in [0, num_tensors)
size_t low = 0;
size_t hi = num_tensors; // half-open [low, hi)
while (low < hi) {
const size_t mid = low + (hi - low) / 2;
const size_t mid_offset = static_cast<size_t>(offsets_ptr[mid]);
if (mid_offset <= current_offset) {
low = mid + 1;
} else {
hi = mid;
}
}
// low = first index where offsets[low] > current_offset (or low == num_tensors)
// id = low - 1, clamped to 0 in case current_offset < offsets[0]
return (low == 0) ? 0 : (low - 1);
}
}
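// Worked example (illustrative): with num_tensors = 3 and offsets = {0, 8192, 24576},
// current_offset = 10000 gives upper_bound index low = 2, so the function returns tensor id 1;
// current_offset = 0 gives low = 1 and returns id 0.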
template <typename IType, int kHadamardDimension, int BUFF_DIM_Y, int BUFF_DIM_X,
bool kReturnPreRhtAmax, bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__device__ __forceinline__ void ComputeKernel(uint32_t b_frag_i[4], uint32_t b_frag_t[4],
IType* in_sh_ptr, uint32_t& local_pre_rht_amax_reg,
uint32_t& local_amax_reg,
uint32_t& local_amax_t_reg) {
uint32_t a_frag[4]; // A matrix fragment
uint32_t c_frag[4]; // Result fragment
int warp_id = threadIdx.x / kThreadsPerWarp;
int local_rank = (threadIdx.x % kThreadsPerWarp);
int ld_row_idx = local_rank % kHadamardDimension;
int ld_col_idx = local_rank / kHadamardDimension + warp_id * 2;
int swizzle_idx = swizzle_128B_atom_32B(ld_row_idx, ld_col_idx);
uint32_t temp_amax_reg;
uint32_t temp_amax_t_reg;
if (kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
mma_m16_n16_k16_b16_b16_b16_noacc<kReturnIdentityAmax>(
a_frag[0], a_frag[1], a_frag[2], a_frag[3], b_frag_i[0], b_frag_i[1], b_frag_i[2],
b_frag_i[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_reg)
: "r"(local_amax_reg), "r"(temp_amax_reg));
}
if (kReturnTransposedAmax) {
// TODO(Frank): This is not efficient, since we could directly load the
// matrix in transposed layout.
if (!kReturnIdentityAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
matrix_transpose_m8_n8_b16_inplace(a_frag[0]);
matrix_transpose_m8_n8_b16_inplace(a_frag[1]);
matrix_transpose_m8_n8_b16_inplace(a_frag[2]);
matrix_transpose_m8_n8_b16_inplace(a_frag[3]);
mma_m16_n16_k16_b16_b16_b16_noacc<kReturnTransposedAmax>(
a_frag[0], a_frag[2], a_frag[1], a_frag[3], b_frag_t[0], b_frag_t[1], b_frag_t[2],
b_frag_t[3], c_frag[0], c_frag[1], c_frag[2], c_frag[3], temp_amax_t_reg);
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_amax_t_reg)
: "r"(local_amax_t_reg), "r"(temp_amax_t_reg));
}
if (kReturnPreRhtAmax) {
if (!kReturnIdentityAmax && !kReturnTransposedAmax) {
ldmatrix_x4_m8n8_shared_b16<false>(a_frag[0], a_frag[1], a_frag[2], a_frag[3],
reinterpret_cast<uint4*>(in_sh_ptr) + swizzle_idx);
}
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[0])
: "r"(a_frag[0]), "r"(a_frag[1]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[2])
: "r"(a_frag[2]), "r"(a_frag[3]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(a_frag[0])
: "r"(a_frag[0]), "r"(a_frag[2]));
asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;\n\t"
: "=r"(local_pre_rht_amax_reg)
: "r"(a_frag[0]), "r"(local_pre_rht_amax_reg));
}
}
template <int kN>
__device__ __host__ constexpr int NextPowerOf2() {
static_assert(kN > 0, "kN must be > 0");
// Round up to the next power of 2 by counting leading zeros.
return 1 << (32 - __builtin_clz(kN - 1));
}
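// Examples (illustrative): NextPowerOf2<3>() == 4, NextPowerOf2<8>() == 8, NextPowerOf2<9>() == 16.
// Note: kN == 1 would make __builtin_clz(0) undefined, so callers are expected to pass kN >= 2.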
template <int kNumWarps, bool kReturnPreRhtAmax, bool kReturnIdentityAmax,
bool kReturnTransposedAmax>
__device__ __forceinline__ void ReduceMax(const float pre_rht_amax, const float identity_amax,
const float transpose_amax, float* staging_for_pre_rht,
float* staging_for_identity, float* staging_for_transpose,
float* output_pre_rht_amax_ptr,
float* output_identity_amax_ptr,
float* output_transpose_amax_ptr, const int warpid) {
// intra-warp reduction
constexpr int kWarpSize = 32;
int local_rank = threadIdx.x % 32;
float warp_pre_rht_amax = kReturnPreRhtAmax ? warp_reduce_max<kWarpSize>(pre_rht_amax) : 0.0f;
float warp_identity_amax = kReturnIdentityAmax ? warp_reduce_max<kWarpSize>(identity_amax) : 0.0f;
float warp_transpose_amax =
kReturnTransposedAmax ? warp_reduce_max<kWarpSize>(transpose_amax) : 0.0f;
// inter-warp reduction
if (threadIdx.x % 32 == 0) {
if (kReturnPreRhtAmax) {
staging_for_pre_rht[warpid] = warp_pre_rht_amax;
}
if (kReturnIdentityAmax) {
staging_for_identity[warpid] = warp_identity_amax;
}
if (kReturnTransposedAmax) {
staging_for_transpose[warpid] = warp_transpose_amax;
}
}
__syncthreads();
constexpr int kNumWarpsPow2 = NextPowerOf2<kNumWarps>();
if (warpid == 0) {
if (kReturnIdentityAmax) {
float identity_accum = local_rank < kNumWarps ? staging_for_identity[local_rank] : 0.0f;
identity_accum = warp_reduce_max<kNumWarpsPow2>(identity_accum);
if (local_rank == 0) {
atomicMaxFloat(output_identity_amax_ptr, identity_accum);
}
}
}
if (warpid == 1) {
if (kReturnTransposedAmax) {
float transpose_accum = local_rank < kNumWarps ? staging_for_transpose[local_rank] : 0.0f;
transpose_accum = warp_reduce_max<kNumWarpsPow2>(transpose_accum);
if (local_rank == 0) {
atomicMaxFloat(output_transpose_amax_ptr, transpose_accum);
}
}
}
if (warpid == 2) {
if (kReturnPreRhtAmax) {
float pre_rht_accum = local_rank < kNumWarps ? staging_for_pre_rht[local_rank] : 0.0f;
pre_rht_accum = warp_reduce_max<kNumWarpsPow2>(pre_rht_accum);
if (local_rank == 0) {
atomicMaxFloat(output_pre_rht_amax_ptr, pre_rht_accum);
}
}
}
}
__global__ void GraphSafeMultiZeroAmaxKernel(const size_t num_tensors, float* amax_rowwise_ptr,
float* amax_colwise_ptr) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
// Assign each thread a range for rowwise and colwise independently
if (amax_rowwise_ptr != nullptr) {
for (int i = tid; i < num_tensors; i += stride) {
amax_rowwise_ptr[i] = 0.f;
}
}
if (amax_colwise_ptr != nullptr) {
for (int i = tid; i < num_tensors; i += stride) {
amax_colwise_ptr[i] = 0.f;
}
}
}
__global__ void GraphSafeMultiAmaxMemcpyD2DKernelPreRHT(const size_t num_tensors,
float* amax_rowwise_ptr,
float* amax_colwise_ptr) {
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
if (amax_rowwise_ptr != nullptr && amax_colwise_ptr != nullptr) {
for (; tid < num_tensors; tid += stride) {
float* output_pre_rht_amax_ptr = amax_rowwise_ptr + tid;
float* output_transpose_amax_ptr = amax_colwise_ptr + tid;
*output_transpose_amax_ptr = *output_pre_rht_amax_ptr;
}
}
}
template <typename IType, int kHadamardDimension, int CHUNK_DIM_Y, int CHUNK_DIM_X, int BUFF_DIM_Y,
int BUFF_DIM_X, int THREADS_PER_CHUNK, int THREADS_PER_Y, bool kReturnPreRhtAmax,
bool kReturnIdentityAmax, bool kReturnTransposedAmax>
__global__ void GraphSafeGroupHadamardAmaxTmaKernel(
const __grid_constant__ CUtensorMap tensor_map_input, uint16_t random_sign_mask,
uint16_t random_sign_mask_t, const ShapeRepresentation shape_rep, const size_t num_tensors,
const size_t first_logical_dim, const size_t last_logical_dim,
const int64_t* const __restrict__ offsets_ptr, const int64_t* const __restrict__ first_dims_ptr,
float* const __restrict__ amax_rowwise_ptr, float* const __restrict__ amax_colwise_ptr) {
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
float* output_pre_rht_amax_ptr;
float* output_identity_amax_ptr = nullptr;
float* output_transpose_amax_ptr;
// calculate the global offset to get tensor id
size_t global_offset = blockIdx.y * CHUNK_DIM_Y * last_logical_dim;
int tensor_id = get_current_tensor_id(shape_rep, num_tensors, global_offset, first_logical_dim,
last_logical_dim, offsets_ptr);
output_pre_rht_amax_ptr = static_cast<float*>(amax_rowwise_ptr) + tensor_id;
output_transpose_amax_ptr = static_cast<float*>(amax_colwise_ptr) + tensor_id;
static_assert(CHUNK_DIM_Y >= BUFF_DIM_Y && CHUNK_DIM_Y % BUFF_DIM_Y == 0);
static_assert(CHUNK_DIM_X >= BUFF_DIM_X && CHUNK_DIM_X % BUFF_DIM_X == 0);
constexpr size_t STAGES_Y = CHUNK_DIM_Y / BUFF_DIM_Y;
constexpr size_t STAGES_X = CHUNK_DIM_X / BUFF_DIM_X;
constexpr int kNumWarps = (THREADS_PER_CHUNK * THREADS_PER_Y) / kThreadsPerWarp;
const int input_block_offset_Y = blockIdx.y * CHUNK_DIM_Y;
const int input_block_offset_X = blockIdx.x * CHUNK_DIM_X;
extern __shared__ __align__(128) char dynamic_shmem[];
uintptr_t base_shmem_ptr = reinterpret_cast<uintptr_t>(dynamic_shmem);
// Manually align dynamic SHMEM per TMA requirements using padding
// __align__(128) does not guarantee that the pointer is aligned!
uint8_t* dshmem = reinterpret_cast<uint8_t*>((base_shmem_ptr + 127) & ~127ULL);
// The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned
constexpr size_t in_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);
IType* in_sh_0 = reinterpret_cast<IType*>(dshmem);
dshmem += in_buff_size;
IType* in_sh_1 = reinterpret_cast<IType*>(dshmem);
dshmem += in_buff_size;
IType* in_shs[2] = {in_sh_0, in_sh_1};
constexpr int shmem_buff_size = BUFF_DIM_X * BUFF_DIM_Y * sizeof(IType);
const bool is_master_thread = (threadIdx.x == 0 && threadIdx.y == 0);
// Initialize shared memory barrier with the number of threads participating in the barrier.
#pragma nv_diag_suppress static_var_with_dynamic_init
uint64_t* mbar = reinterpret_cast<uint64_t*>(dshmem);
dshmem += sizeof(uint64_t) * (STAGES_X * STAGES_Y);
float* max_staging_identity = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
float* max_staging_transpose = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
float* max_staging_pre_rht = reinterpret_cast<float*>(dshmem);
dshmem += sizeof(float) * kNumWarps;
initialize_barriers<STAGES_X * STAGES_Y, THREADS_PER_CHUNK * THREADS_PER_Y>(mbar,
is_master_thread);
copy_2d_to_shared(in_shs[0], reinterpret_cast<const void*>(&tensor_map_input),
input_block_offset_X, input_block_offset_Y, shmem_buff_size, &mbar[0],
is_master_thread);
uint32_t had_frag_i[4];
uint32_t had_frag_t[4];
get_hadamard_matrix_fragment<kReturnIdentityAmax, kReturnTransposedAmax, false, false>(
had_frag_i, random_sign_mask, had_frag_t, random_sign_mask_t);
float local_pre_rht_amax = 0.0;
float local_amax = 0.0;
float local_amax_t = 0.0;
uint32_t local_pre_rht_amax_reg = *reinterpret_cast<uint32_t*>(&local_pre_rht_amax);
uint32_t local_amax_reg = *reinterpret_cast<uint32_t*>(&local_amax);
uint32_t local_amax_t_reg = *reinterpret_cast<uint32_t*>(&local_amax_t);
for (int stage_y = 0; stage_y < STAGES_Y; ++stage_y) {
for (int stage_x = 0; stage_x < STAGES_X; ++stage_x) {
int stage = STAGES_X * stage_y + stage_x;
const int next_stage = stage + 1;
const int next_stage_x = stage_x + 1 == STAGES_X ? 0 : stage_x + 1;
const int next_stage_y = stage_x + 1 == STAGES_X ? stage_y + 1 : stage_y;
if (next_stage < STAGES_X * STAGES_Y) {
const int input_global_offset_Y = input_block_offset_Y + next_stage_y * BUFF_DIM_Y;
const int input_global_offset_X = input_block_offset_X + next_stage_x * BUFF_DIM_X;
copy_2d_to_shared(in_shs[next_stage % 2], // ping-pong
reinterpret_cast<const void*>(&tensor_map_input), input_global_offset_X,
input_global_offset_Y, shmem_buff_size, &mbar[next_stage],
is_master_thread);
}
ptx::fence_proxy_async_shared_cta();
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], 0);
const size_t compute_stage_x_num =
BUFF_DIM_X / (kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp));
const size_t compute_stage_y_num = BUFF_DIM_Y / (kHadamardDimension * THREADS_PER_Y);
const size_t in_row_stride = BUFF_DIM_X;
IType* in_sh_ptr = in_shs[stage % 2];
#pragma unroll
for (size_t compute_stage_y = 0; compute_stage_y < compute_stage_y_num; compute_stage_y++) {
const int row_idx_offset = (compute_stage_y * kHadamardDimension * THREADS_PER_Y +
threadIdx.y * kHadamardDimension);
const int in_row_offset = row_idx_offset * in_row_stride;
#pragma unroll
for (size_t compute_stage_x = 0; compute_stage_x < compute_stage_x_num; compute_stage_x++) {
ComputeKernel<IType, kHadamardDimension, BUFF_DIM_Y, BUFF_DIM_X, kReturnPreRhtAmax,
kReturnIdentityAmax, kReturnTransposedAmax>(
had_frag_i, had_frag_t,
in_sh_ptr + in_row_offset +
(compute_stage_x * kHadamardDimension * (THREADS_PER_CHUNK / kThreadsPerWarp)),
local_pre_rht_amax_reg, local_amax_reg, local_amax_t_reg);
}
// Ensure all threads have finished their computation before new data over-writes the shared
// memory.
__syncthreads();
}
}
}
const int warpid = (threadIdx.x + threadIdx.y * blockDim.x) / kThreadsPerWarp;
if constexpr (kReturnPreRhtAmax) {
unpack_max_of_packed_bf16(local_pre_rht_amax_reg, local_pre_rht_amax);
}
if constexpr (kReturnIdentityAmax) {
unpack_max_of_packed_bf16(local_amax_reg, local_amax);
}
if constexpr (kReturnTransposedAmax) {
unpack_max_of_packed_bf16(local_amax_t_reg, local_amax_t);
}
ReduceMax<kNumWarps, kReturnPreRhtAmax, kReturnIdentityAmax, kReturnTransposedAmax>(
local_pre_rht_amax, local_amax, local_amax_t, max_staging_pre_rht, max_staging_identity,
max_staging_transpose, output_pre_rht_amax_ptr, output_identity_amax_ptr,
output_transpose_amax_ptr, warpid);
destroy_barriers<STAGES_X * STAGES_Y>(mbar, is_master_thread);
#else
NVTE_DEVICE_ERROR("Kernel is only supported on SM 10.0+.");
#endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
}
} // namespace
// broadcast_pre_rht_amax: when true, the Hadamard transform is disabled.
// If the output amax buffers expect both amax_rowwise and amax_colwise, the pre-RHT amax is
// then D2D-copied into them by MultiAmaxMemcpyD2DKernelPreRHT.
void group_hadamard_transform_amax_graph_safe(const GroupedTensor* input, GroupedTensor* output,
uint16_t random_sign_mask,
uint16_t random_sign_mask_t,
bool broadcast_pre_rht_amax, cudaStream_t stream) {
NVTE_API_CALL(group_hadamard_transform_amax_graph_safe);
#if CUDA_VERSION >= 12080
NVTE_CHECK(input->num_tensors == output->num_tensors,
           "Number of input and output tensors must be the same.");
NVTE_CHECK(input->has_data(), "Cannot compute amax for a grouped tensor without rowwise data.");
checkCuDriverContext(stream);
bool all_return_pre_rht_amax = output->has_data();
// there is no row-wise RHT transform in the current recipe
bool all_return_identity_amax = false;
bool all_return_transposed_amax = output->has_columnwise_data();
NVTE_CHECK(all_return_pre_rht_amax || all_return_identity_amax || all_return_transposed_amax,
"At least one of return_pre_rht_amax, return_identity_amax, or return_transposed_amax "
"must be true");
if (broadcast_pre_rht_amax) {
NVTE_CHECK(all_return_pre_rht_amax,
"broadcast_pre_rht_amax is only supported when we compute pre-RHT amax");
// if both all_return_identity_amax and all_return_transposed_amax are false, there is nothing to broadcast
broadcast_pre_rht_amax &= (all_return_identity_amax || all_return_transposed_amax);
}
const size_t num_tensors = input->num_tensors;
const size_t first_logical_dim = input->logical_shape.data[0];
const size_t last_logical_dim = input->logical_shape.data[1];
// const size_t elts_total = first_logical_dim * last_logical_dim;
NVTE_CHECK(first_logical_dim % 128 == 0,
"First dimension of a grouped tensor should be divisible by 128.");
NVTE_CHECK(last_logical_dim % 128 == 0,
"Last dimension of a grouped tensor should be divisible by 128.");
float* const amax_rowwise_ptr = reinterpret_cast<float*>(output->amax.dptr);
float* const amax_colwise_ptr = reinterpret_cast<float*>(output->columnwise_amax.dptr);
const int64_t* const offsets_ptr = reinterpret_cast<const int64_t*>(input->tensor_offsets.dptr);
const int64_t* const first_dims_ptr = reinterpret_cast<const int64_t*>(input->first_dims.dptr);
// const int64_t *const last_dims_ptr = reinterpret_cast<const int64_t *>(input->last_dims.dptr);
// some sanity checks
if (all_return_pre_rht_amax) {
NVTE_CHECK(amax_rowwise_ptr != nullptr, "Amax rowwise pointer should not be nullptr.");
}
if (all_return_transposed_amax) {
NVTE_CHECK(amax_colwise_ptr != nullptr, "Amax columnwise pointer should not be nullptr.");
}
// Zero out the per-tensor amax buffers in a single kernel launch if needed
dim3 block_setup_amax(kMaxTensorsPerKernel);
dim3 grid_setup_amax(1);
GraphSafeMultiZeroAmaxKernel<<<grid_setup_amax, block_setup_amax, 0, stream>>>(
num_tensors, amax_rowwise_ptr, amax_colwise_ptr);
NVTE_CHECK_CUDA(cudaGetLastError());
using IType = bf16;
constexpr int kHadamardDimension = 16;
// four (1x4) 64x64 sub-tiles for ping-pong overlap
constexpr uint64_t kChunkBlockXSmall = 256;
constexpr uint64_t kChunkBlockYSmall = 64;
constexpr uint64_t kBuffDimX = 64;
constexpr uint64_t kBuffDimY = 64;
alignas(64) CUtensorMap tensor_map_input{};
create_2D_tensor_map(
/*tensorMap=*/tensor_map_input,
/*tensor=*/input->data,
/*globalY=*/first_logical_dim,
/*globalX=*/last_logical_dim,
/*shmemY=*/kBuffDimY,
/*shmemX=*/kBuffDimX,
/*stride_elems=*/last_logical_dim,
/*offset_elems=*/0,
/*type_num_bits=*/sizeof(IType) * 8,
/*swizzle=*/CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_128B_ATOM_32B);
constexpr uint64_t kThreadBlockX = 4;
constexpr uint64_t kThreadBlockY = 1;
constexpr uint64_t kNumWarps = kThreadBlockX * kThreadBlockY;
dim3 block(kThreadBlockX * kThreadsPerWarp, kThreadBlockY);
dim3 grid(DIVUP(last_logical_dim, kChunkBlockXSmall),
DIVUP(first_logical_dim, kChunkBlockYSmall));
ShapeRepresentation shape_rep = ShapeRepresentation::VARYING_FIRST_DIM;
if (output->all_same_shape()) {
shape_rep = ShapeRepresentation::SAME_BOTH_DIMS;
} else if (output->all_same_first_dim()) {
shape_rep = ShapeRepresentation::VARYING_LAST_DIM;
} else if (output->all_same_last_dim()) {
shape_rep = ShapeRepresentation::VARYING_FIRST_DIM;
} else if (output->varying_both_dims()) {
shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS;
}
const bool is_const_last_dim = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS ||
shape_rep == ShapeRepresentation::VARYING_FIRST_DIM);
NVTE_CHECK(is_const_last_dim,
"Currently we only support const last dimension for graph safe hadamard transform.");
TRANSFORMER_ENGINE_SWITCH_CONDITION(
(all_return_transposed_amax && !broadcast_pre_rht_amax), kReturnTransposedAmax,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
(all_return_identity_amax && !broadcast_pre_rht_amax), kReturnIdentityAmax,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
all_return_pre_rht_amax, kReturnPreRhtAmax,
// *2 for ping-pong
size_t in_sh_size = kBuffDimX * kBuffDimY * 2 * sizeof(IType);
size_t mbar_size = sizeof(uint64_t) * (kChunkBlockXSmall / kBuffDimX) *
(kChunkBlockYSmall / kBuffDimY);
size_t shmem_bytes = in_sh_size + mbar_size + kNumWarps * sizeof(float) * 3;
// Add padding in case shmem ptr is not aligned to 128 bytes.
shmem_bytes = (shmem_bytes + 128);
auto kernel = GraphSafeGroupHadamardAmaxTmaKernel<
IType, kHadamardDimension, kChunkBlockYSmall, kChunkBlockXSmall, kBuffDimY,
kBuffDimX, kThreadBlockX * kThreadsPerWarp, kThreadBlockY, kReturnPreRhtAmax,
kReturnIdentityAmax, kReturnTransposedAmax>;
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
shmem_bytes);
kernel<<<grid, block, shmem_bytes, stream>>>(
tensor_map_input, random_sign_mask, random_sign_mask_t, shape_rep, num_tensors,
first_logical_dim, last_logical_dim, offsets_ptr, first_dims_ptr,
amax_rowwise_ptr, amax_colwise_ptr);
if (broadcast_pre_rht_amax) {
GraphSafeMultiAmaxMemcpyD2DKernelPreRHT<<<grid_setup_amax, block_setup_amax, 0,
stream>>>(num_tensors, amax_rowwise_ptr,
amax_colwise_ptr);
})));
NVTE_CHECK_CUDA(cudaGetLastError());
#else
NVTE_ERROR("Hadamard transform requires CUDA 12.8+, but compile-time CUDA version is ",
CUDA_VERSION);
#endif // CUDA_VERSION >= 12080
}
} // namespace transformer_engine
void nvte_group_hadamard_transform_amax_graph_safe(const NVTEGroupedTensor input,
NVTEGroupedTensor output, int random_sign_mask,
int random_sign_mask_t, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_hadamard_transform_amax_graph_safe);
using namespace transformer_engine;
GroupedTensor* input_tensor = convertNVTEGroupedTensorCheck(input);
GroupedTensor* output_tensor = convertNVTEGroupedTensorCheck(output);
if (input_tensor->num_tensors == 0) {
return;
}
// Call the group tensor Hadamard transform amax implementation.
group_hadamard_transform_amax_graph_safe(
input_tensor, output_tensor, static_cast<uint16_t>(random_sign_mask),
static_cast<uint16_t>(random_sign_mask_t), false, stream);
}
// Grouped-tensor amax without applying the Hadamard transform
void nvte_group_amax_graph_safe(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_group_amax_graph_safe);
using namespace transformer_engine;
GroupedTensor* input_tensor = convertNVTEGroupedTensorCheck(input);
GroupedTensor* output_tensor = convertNVTEGroupedTensorCheck(output);
if (input_tensor->num_tensors == 0) {
return;
}
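  // With sign masks of 0 and broadcast_pre_rht_amax=true, only the pre-RHT amax is computed
  // and, if a columnwise amax buffer is also expected, it is broadcast there as well.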
group_hadamard_transform_amax_graph_safe(input_tensor, output_tensor, 0, 0, true, stream);
}
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/
#include <cuda.h>
#include <cudaTypedefs.h>
#include <cuda_bf16.h>
#include <cuda_pipeline.h>
#include <cuda_runtime.h>
#include <cutlass/arch/barrier.h>
#include <transformer_engine/hadamard_transform.h>
#include <cuda/barrier>
#include <cute/algorithm/gemm.hpp>
#include <cute/arch/cluster_sm90.hpp>
#include <cute/tensor.hpp>
#include "common/common.h"
#include "common/util/cuda_runtime.h"
#include "common/util/curanddx.hpp"
#include "common/util/ptx.cuh"
#include "common/utils.cuh"
#include "customized_pipeline.cuh"
#include "cutlass/arch/barrier.h"
#include "cutlass/arch/reg_reconfig.h"
#include "cutlass/cluster_launch.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/detail/sm100_blockscaled_layout.hpp"
#include "cutlass/fast_math.h"
#include "cutlass/float8.h"
#include "cutlass/float_subbyte.h"
#include "cutlass/gemm/collective/builders/sm100_common.inl"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/platform/platform.h"
#include "cutlass/util/GPU_Clock.hpp"
#include "cutlass/util/command_line.h"
#include "cutlass/util/print_error.hpp"
namespace transformer_engine {
namespace detail {
namespace {
using namespace cute;
// Ensure Tensor refers to cute::Tensor, not transformer_engine::Tensor
using cute::Tensor;
constexpr int kMaxTensorsPerKernel = 64;
constexpr int kNVFP4BlockSize = 16;
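// Describes how per-tensor shapes vary within a grouped tensor; kernels use it together with
// the tensor offsets to map a flat element offset back to the owning tensor
// (see get_current_tensor_id below).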
enum ShapeRepresentation {
SAME_BOTH_DIMS = 0,
VARYING_FIRST_DIM = 1,
VARYING_LAST_DIM = 2,
VARYING_BOTH_DIMS = 3
};
__device__ __forceinline__ size_t get_current_tensor_id(
const ShapeRepresentation shape_rep, const size_t num_tensors, const size_t current_offset,
const size_t first_logical_dim, const size_t last_logical_dim,
const int64_t *const __restrict__ offsets_ptr) {
if (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) {
const size_t current_row = current_offset / last_logical_dim;
const size_t rows_per_tensor = first_logical_dim / num_tensors;
return current_row / rows_per_tensor;
} else {
    // Binary search: upper_bound(offsets, current_offset) - 1 over indices i in [0..num_tensors)
size_t low = 0;
size_t hi = num_tensors; // half-open [low, hi)
while (low < hi) {
const size_t mid = low + (hi - low) / 2;
const size_t mid_offset = static_cast<size_t>(offsets_ptr[mid]);
if (mid_offset <= current_offset) {
low = mid + 1;
} else {
hi = mid;
}
}
// low = first index where offsets[low] > current_offset (or low == num_tensors)
    // id = low - 1, guarding against the case current_offset < offsets[0]
return (low == 0) ? 0 : (low - 1);
}
}
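// Convert 8 FP32 values to 8 packed FP4 (e2m1) values with stochastic rounding: each
// cvt.rs.satfinite.e2m1x4.f32 instruction converts four floats using 32 bits of randomness.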
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 8> StochasticNumericConverterBase(
cutlass::Array<float, 8> const &input, cutlass::Array<uint32_t, 2> const &rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 8>;
result_type output;
auto output_ptr = reinterpret_cast<uint16_t *>(&output);
constexpr bool has_rs = ARCH_HAS_STOCHASTIC_ROUNDING;
if constexpr (has_rs) {
asm volatile(
"{\n"
"cvt.rs.satfinite.e2m1x4.f32 %0, {%5, %4, %3, %2}, %10;\n"
"cvt.rs.satfinite.e2m1x4.f32 %1, {%9, %8, %7, %6}, %11;\n"
"}"
: "=h"(output_ptr[0]), "=h"(output_ptr[1])
: "f"(input[0]), "f"(input[1]), "f"(input[2]), "f"(input[3]), "f"(input[4]), "f"(input[5]),
"f"(input[6]), "f"(input[7]), "r"(rbits[0]), "r"(rbits[1]));
} else {
NVTE_DEVICE_ERROR(
"FP4 cvt PTX instructions are architecture-specific. "
"Try recompiling with sm_XXXa instead of sm_XXX.");
}
return output;
}
CUTLASS_DEVICE
cutlass::Array<cutlass::float_e2m1_t, 16> StochasticNumericConverter(
cutlass::Array<float, 16> const &input, cutlass::Array<uint32_t, 4> const &rbits) {
using result_type = cutlass::Array<cutlass::float_e2m1_t, 16>;
result_type output;
cutlass::Array<cutlass::float_e2m1_t, 8> *result_ptr =
reinterpret_cast<cutlass::Array<cutlass::float_e2m1_t, 8> *>(&output);
cutlass::Array<float, 8> const *source_ptr =
reinterpret_cast<cutlass::Array<float, 8> const *>(&input);
cutlass::Array<uint32_t, 2> const *rbits_ptr =
reinterpret_cast<cutlass::Array<uint32_t, 2> const *>(&rbits);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < 2; i++) {
result_ptr[i] = StochasticNumericConverterBase(source_ptr[i], rbits_ptr[i]);
}
return output;
}
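// Shared-memory layout for the fused RHT GEMM kernel: A/B input tiles, mainloop/accumulator/
// scheduler pipeline barriers, per-tensor amax staging, tile-scheduler counters, and the TMEM
// base pointer.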
template <class ElementA, class ElementB, class ASmemLayout, class BSmemLayout, class ClusterShape,
int AccumulatorPipelineStageCount_, int EpilogueUnrollFactor_,
int SchedulerPipelineStageCount_>
struct SharedStorage {
static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
static int constexpr EpilogueUnrollFactor = EpilogueUnrollFactor_;
using AtomThrShapeMNK = cute::Shape<_1, _1, _1>;
using AccumulatorPipeline =
cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / EpilogueUnrollFactor,
AtomThrShapeMNK>;
using AccumulatorPipelineStorage = typename AccumulatorPipeline::SharedStorage;
static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{});
using MainloopPipeline =
cutlass::detail::CustomizedPipelineTmaUmmaAsync<MainloopPipelineStageCount, Shape<_1, _1, _1>,
AtomThrShapeMNK>;
using MainloopPipelineStorage = typename MainloopPipeline::SharedStorage;
using SchedPipeline = cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount_, ClusterShape>;
using SchedPipelineStorage = typename SchedPipeline::SharedStorage;
using SchedThrottlePipeline = cutlass::PipelineAsync<SchedulerPipelineStageCount_>;
using SchedThrottlePipelineStorage = typename SchedThrottlePipeline::SharedStorage;
struct TensorStorage : cute::aligned_struct<128, _1> {
cute::array_aligned<ElementA, cute::cosize_v<ASmemLayout>> smem_A;
cute::array_aligned<ElementB, cute::cosize_v<BSmemLayout>> smem_B;
} tensors;
alignas(16) AccumulatorPipelineStorage accumulator;
alignas(16) MainloopPipelineStorage mainloop;
alignas(16) cute::uint64_t tma_barrier[1];
alignas(16) SchedPipelineStorage sched;
alignas(16) SchedThrottlePipelineStorage sched_throttle;
alignas(16) int32_t atomic_tile_id[SchedulerPipelineStageCount_];
alignas(16) float global_a_amax[kMaxTensorsPerKernel];
alignas(16) float global_d_amax[kMaxTensorsPerKernel];
uint32_t atomic_tile_counter[SchedulerPipelineStageCount_];
uint32_t tmem_base_ptr;
};
// Main RHT GEMM kernel entry -- highly templated for flexible architecture/config support
template <class MShape, class NShape, class KShape, class ClusterShape, class ClusterTileShape,
class TA, class AStride, class ASmemLayout, class TmaLoadA, class TB, class BStride,
class BSmemLayout, class TmaLoadB, class TD, class DStride, class DSmemLayout, class TSFD,
class TSFDLayout, class TQA, class QAStride, class TSFA, class TSFALayout, class TiledMMA,
int AccumulatorPipelineStageCount_, int SchedulerPipelineStageCount_,
bool kEnableStochasticRounding_ = false, bool kEnableRHTColQuant_ = true,
bool kEnableRowQuant_ = true, bool kEnableSwizzleSFOutput_ = false,
bool kUseFastMath_ = false>
__launch_bounds__(512, 1) __global__ static void group_row_col_rht_gemm_device_graph_safe(
MShape M, NShape packed_N, KShape K, ClusterShape cluster_shape, ClusterTileShape cluster_tile,
TA const *A, AStride dA, ASmemLayout sAlayout, CUTE_GRID_CONSTANT TmaLoadA const tma_load_a,
TB const *B, BStride dB, BSmemLayout sBlayout, CUTE_GRID_CONSTANT TmaLoadB const tma_load_b,
TQA *QA, QAStride dQA, TSFA *SFA, TSFALayout sfa_layout, TQA *QA_COLWISE, TSFA *SFA_COLWISE,
float *amax_rowwise, float *amax_colwise, const int64_t *offsets, const int64_t *first_dims,
size_t num_tensors, ShapeRepresentation shape_rep, uint32_t *tile_scheduler_workspace,
TiledMMA mma, const size_t *rng_state) {
using namespace cute;
// Abort immediately if compilation is not supported
constexpr bool is_blackwell_arch = ARCH_BLACKWELL_FAMILY;
if constexpr (!is_blackwell_arch) {
NVTE_DEVICE_ERROR(
"group_row_col_rht_gemm_device_graph_safe is only supported on Blackwell "
"with architecture-specific compilation. "
"Try recompiling with sm_100a or similar.");
return;
}
static_assert(kEnableRHTColQuant_ || kEnableRowQuant_,
"group_row_col_rht_gemm_device_graph_safe must generate row-wise "
"and/or column-wise output.");
#if !defined(CUTLASS_ARCH_CLC_ENABLED)
CUTLASS_NOT_IMPLEMENTED();
return;
#endif
using X = Underscore;
// Accumulator data type for main computation
using ElementAccumulator = float;
static int constexpr K_PIPE_MAX = size<3>(ASmemLayout{});
using AtomThrShapeMNK = Shape<decltype(shape<0>(typename TiledMMA::ThrLayoutVMNK{})), _1, _1>;
static uint32_t constexpr kTmaTransactionBytes = cutlass::bits_to_bytes(
size(AtomThrShapeMNK{}) * cosize(take<0, 3>(ASmemLayout{})) * cute::sizeof_bits_v<TA>);
static constexpr bool kEnableStochasticRounding = kEnableStochasticRounding_;
static constexpr bool kEnableRHTColQuant = kEnableRHTColQuant_;
static constexpr bool kEnableRowQuant = kEnableRowQuant_;
static constexpr bool kEnableSwizzleSFOutput = kEnableSwizzleSFOutput_;
static constexpr bool kUseFastMath = kUseFastMath_;
// Constant for RHT tensor processing (tile size etc)
static int constexpr RhtTensorSize = 16;
// Get the total number of tokens to process
// Note that M is the hidden size, i.e. the last logical dimension of the input tensor x;
// the kernel operates in column-major order, so M is the leading (contiguous) dimension.
size_t sum_token_dims = offsets[num_tensors] / M;
// Transaction bytes for TMA transfer on RHT tensor blocks
static int constexpr kTmaRhtTensorTransactionBytes =
cutlass::bits_to_bytes(RhtTensorSize * RhtTensorSize * cute::sizeof_bits_v<TB>);
static int constexpr AccumulatorPipelineStageCount = AccumulatorPipelineStageCount_;
static int constexpr SchedulerPipelineStageCount = SchedulerPipelineStageCount_;
// Mainloop pipeline stage calculation, vectorization parameters for scaling factors
static int constexpr MainloopPipelineStageCount = size<3>(ASmemLayout{});
static int constexpr SFVecSize = 16;
// Swizzle output layout for scaling factor arrays
using SwizzledSFALayoutAtom =
cutlass::detail::Sm1xxBlockScaledOutputConfig<SFVecSize, UMMA::Major::MN>::SfAtom;
using SwizzledSFDLayoutAtom =
cutlass::detail::Sm1xxBlockScaledOutputConfig<SFVecSize, UMMA::Major::K>::SfAtom;
// Mainloop pipeline types for TMA async execution and epilogue cluster scheduling
using MainloopPipeline =
cutlass::detail::CustomizedPipelineTmaUmmaAsync<MainloopPipelineStageCount, ClusterShape,
AtomThrShapeMNK>;
using MainloopPipelineState = typename MainloopPipeline::PipelineState;
using SchedPipeline = cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount, ClusterShape>;
using SchedPipelineState = typename SchedPipeline::PipelineState;
using SchedThrottlePipeline = cutlass::PipelineAsync<SchedulerPipelineStageCount>;
using SchedThrottlePipelineState = typename SchedThrottlePipeline::PipelineState;
static_assert(ClusterShape{} == Shape<_1, _1, _1>{}, "ClusterShape must be Shape<_1,_1,_1>");
using TmemAllocator = cute::TMEM::Allocator1Sm;
static int constexpr VectorSize = RhtTensorSize;
// Compile-time safety: static shapes required for shared memory layouts
CUTE_STATIC_ASSERT(is_static<ASmemLayout>::value);
CUTE_STATIC_ASSERT(is_static<BSmemLayout>::value);
// CUTE_STATIC_ASSERT(is_static<DSmemLayout>::value);
auto cluster_size = size<0>(cluster_shape);
auto mainloop_tiler = Shape<_128, _16, _128>{};
auto epilogue_tiler = Shape<_128, _128, _128>{};
static int constexpr EpilogueUnrollFactor = size<2>(epilogue_tiler) / size<2>(cluster_tile);
// Get the appropriate blocks for this Cluster
dim3 cluster_coord_in_grid = cluster_id_in_grid();
// Total number of k-tiles
int const K_TILE_MAX = min(packed_N, K) / size<2>(epilogue_tiler);
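// Lightweight persistent-CTA tile scheduler: each CTA starts from its blockIdx.x tile and
// then fetches further linear tile indices from a global atomic counter, distributing them
// to the other warps through the CLC scheduler pipeline.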
struct TileScheduler {
uint32_t tiles_in_m = 0;
uint32_t tiles_in_n = 0;
uint32_t linear_idx = 0;
uint32_t next_linear_idx = 0;
uint32_t start_idx = 0;
uint32_t tile_m_idx = 0;
uint32_t tile_n_idx = 0;
int k_tile_max = 0;
uint32_t *atomic_tile_index_;
uint32_t *smem_tile_counter;
uint32_t atomic_offset;
cutlass::FastDivmodU64 divmod_tiles_in_m;
CUTLASS_DEVICE TileScheduler(uint32_t tiles_m, uint32_t tiles_n, int kmax,
uint32_t *atomic_tile_index, uint32_t *smem_tile_counter)
: tiles_in_m(tiles_m),
tiles_in_n(tiles_n),
linear_idx(blockIdx.x),
next_linear_idx(blockIdx.x),
start_idx(blockIdx.x),
k_tile_max(kmax),
atomic_tile_index_(atomic_tile_index),
smem_tile_counter(smem_tile_counter),
atomic_offset(gridDim.x),
divmod_tiles_in_m(uint64_t(tiles_m)) {
update_tile_idx();
}
CUTLASS_DEVICE void update_tile_idx() {
uint64_t q, r;
divmod_tiles_in_m(q, r, uint64_t(linear_idx));
tile_m_idx = static_cast<uint32_t>(r);
tile_n_idx = static_cast<uint32_t>(q) * uint32_t(k_tile_max);
}
CUTLASS_DEVICE uint32_t tile_m() const { return tile_m_idx; }
CUTLASS_DEVICE uint32_t tile_n_base() const { return tile_n_idx; }
CUTLASS_DEVICE uint32_t tiles_m() const { return tiles_in_m; }
CUTLASS_DEVICE uint32_t tiles_n() const { return tiles_in_n; }
CUTLASS_DEVICE bool is_valid() const {
return cute::elem_less(cute::make_coord(tile_m(), tile_n_base()),
cute::make_coord(tiles_in_m, tiles_in_n));
}
CUTLASS_DEVICE bool is_first_wave() const { return linear_idx == start_idx; }
CUTLASS_DEVICE uint32_t get_linear_tile_idx() const { return linear_idx; }
// Fetch a new tile_id using atomics.
CUTLASS_DEVICE uint32_t fetch_tile_id_counter(int pred) {
uint32_t tile_id_counter = 0;
asm volatile(
"{\n\t"
".reg .pred p;\n\t"
"setp.eq.u32 p, %2, 1;\n\t"
"@p atom.global.add.u32 %0, [%1], 1; \n\t"
"}"
: "=r"(tile_id_counter)
: "l"(atomic_tile_index_), "r"(pred));
return tile_id_counter;
}
CUTLASS_DEVICE auto fetch_next_work(SchedPipeline &sched_pipeline,
SchedPipelineState sched_pipeline_consumer_state) {
sched_pipeline.consumer_wait(sched_pipeline_consumer_state);
next_linear_idx = smem_tile_counter[sched_pipeline_consumer_state.index()];
cutlass::arch::fence_view_async_shared();
sched_pipeline.consumer_release(sched_pipeline_consumer_state);
return;
}
CUTLASS_DEVICE auto advance_to_next_work(SchedPipeline &sched_pipeline,
SchedPipelineState sched_pipeline_producer_state) {
uint32_t mbarrier_addr = sched_pipeline.producer_get_barrier(sched_pipeline_producer_state);
// Wait for clcID buffer to become empty with a flipped phase
sched_pipeline.producer_acquire(sched_pipeline_producer_state);
auto is_leading_thread = cute::elect_one_sync();
uint32_t tile_id_counter = fetch_tile_id_counter(is_leading_thread) + atomic_offset;
uint32_t smem_addr =
cute::cast_smem_ptr_to_uint(&smem_tile_counter[sched_pipeline_producer_state.index()]);
if (is_leading_thread) {
cute::store_shared_remote(tile_id_counter, smem_addr, mbarrier_addr, 0);
}
++sched_pipeline_producer_state;
return sched_pipeline_producer_state;
}
CUTLASS_DEVICE auto update_work_tile_info() {
linear_idx = next_linear_idx;
update_tile_idx();
return;
}
};
// Allocate and alias shared memory to the kernel's shared storage type
extern __shared__ char shared_memory[];
using SharedStorage =
SharedStorage<TA, TB, ASmemLayout, BSmemLayout, ClusterShape, AccumulatorPipelineStageCount,
EpilogueUnrollFactor, SchedulerPipelineStageCount>;
SharedStorage &shared_storage = *reinterpret_cast<SharedStorage *>(shared_memory);
// Compute the number of tiles in M and N after tiling and assign scheduler
uint32_t tiles_in_m = uint32_t(size(ceil_div(M, size<0>(cluster_tile))));
uint32_t tiles_in_n = uint32_t(size(ceil_div(sum_token_dims, size<2>(epilogue_tiler))));
TileScheduler scheduler(tiles_in_m, tiles_in_n, K_TILE_MAX, tile_scheduler_workspace,
shared_storage.atomic_tile_counter);
int block_rank_in_cluster = cute::block_rank_in_cluster();
// Shapes for accumulated tiles in mainloop and epilogue
auto acc_shape_mma = make_shape(take<0, 2>(mainloop_tiler), _1{}, _1{});
auto acc_shape_epilogue = make_shape(take<0, 2>(epilogue_tiler), _1{}, _1{});
// Shape of the accumulator fragment for the main loop pipeline, with pipeline stages appended
auto acc_mainloop_pipelined_shape = append(acc_shape_mma, Int<AccumulatorPipelineStageCount>{});
auto bulk_tmem_mma = TiledMMA::make_fragment_C(acc_mainloop_pipelined_shape);
// Number of threads assigned for various epilogue roles depending on quantization settings
static int constexpr NumEpilogueColQuantThreadCount = kEnableRHTColQuant ? 128 : 0;
static int constexpr NumEpilogueRowQuantThreadCount = kEnableRowQuant ? 256 : 0;
static int constexpr NumMmaThreadCount = kEnableRHTColQuant ? 32 : 0;
static int constexpr NumMmaIssueThreadCount = kEnableRHTColQuant ? 1 : 0;
static int constexpr NumSchedThreads = 32;
static int constexpr NumMainloopLoadThreads = 32;
static int constexpr NumEpilogueThreads =
NumEpilogueColQuantThreadCount + NumEpilogueRowQuantThreadCount;
TmemAllocator tmem_allocator{};
cutlass::arch::NamedBarrier tmem_allocation_result_barrier(
NumMmaThreadCount + NumEpilogueColQuantThreadCount,
cutlass::arch::ReservedNamedBarriers::TmemAllocBarrier);
int warp_idx = cutlass::canonical_warp_idx_sync();
// warp assignment
bool is_mma_warp = (warp_idx == 0);
bool is_dma_warp = (warp_idx == 1);
bool is_sched_warp = (warp_idx == 2);
bool is_epilogue_col_quant_warp = (warp_idx >= 4 && warp_idx <= 7);
bool is_epilogue_row_quant_warp = (warp_idx >= 8 && warp_idx <= 15);
typename MainloopPipeline::Params mainloop_pipeline_params;
if (is_dma_warp) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer;
}
if (is_mma_warp) {
mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer;
}
mainloop_pipeline_params.is_leader = cute::elect_one_sync() && is_dma_warp;
mainloop_pipeline_params.transaction_bytes = kTmaTransactionBytes;
mainloop_pipeline_params.initializing_warp = 0;
mainloop_pipeline_params.num_consumers = NumEpilogueRowQuantThreadCount + NumMmaIssueThreadCount;
MainloopPipeline mainloop_pipeline(shared_storage.mainloop, mainloop_pipeline_params,
cluster_shape, cute::true_type{}, // Perform barrier init
cute::true_type{}); // Delay mask calculation
MainloopPipelineState mainloop_pipe_consumer_state;
MainloopPipelineState mainloop_pipe_producer_state =
cutlass::make_producer_start_state<MainloopPipeline>();
using AccumulatorPipeline =
cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / EpilogueUnrollFactor,
AtomThrShapeMNK>;
using AccumulatorPipelineState = typename AccumulatorPipeline::PipelineState;
using AccumulatorPipelineInitBarriers = cute::bool_constant<kEnableRHTColQuant>;
AccumulatorPipelineState accumulator_pipe_consumer_state;
AccumulatorPipelineState accumulator_pipe_producer_state =
cutlass::make_producer_start_state<AccumulatorPipeline>();
typename AccumulatorPipeline::Params accumulator_pipeline_params;
if (is_mma_warp) {
accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Producer;
}
if (is_epilogue_col_quant_warp) {
accumulator_pipeline_params.role = AccumulatorPipeline::ThreadCategory::Consumer;
}
// Only one producer thread arrives on this barrier.
accumulator_pipeline_params.producer_arv_count = 1;
accumulator_pipeline_params.consumer_arv_count =
size(AtomThrShapeMNK{}) * NumEpilogueColQuantThreadCount;
accumulator_pipeline_params.initializing_warp = 1;
AccumulatorPipeline accumulator_pipeline(shared_storage.accumulator, accumulator_pipeline_params,
cluster_shape, AccumulatorPipelineInitBarriers{},
cute::true_type{}); // Delay mask calculation
typename SchedPipeline::Params sched_pipeline_params;
if (is_sched_warp) {
sched_pipeline_params.role = SchedPipeline::ThreadCategory::ProducerConsumer;
} else {
sched_pipeline_params.role = SchedPipeline::ThreadCategory::Consumer;
}
sched_pipeline_params.producer_blockid = 0;
sched_pipeline_params.producer_arv_count = 1;
sched_pipeline_params.consumer_arv_count =
NumSchedThreads +
cluster_size * (NumMainloopLoadThreads + NumEpilogueThreads + NumMmaThreadCount);
sched_pipeline_params.transaction_bytes = sizeof(uint32_t);
sched_pipeline_params.initializing_warp = 3;
SchedPipeline sched_pipeline(shared_storage.sched, sched_pipeline_params, cluster_shape);
SchedPipelineState sched_pipeline_consumer_state;
SchedPipelineState sched_pipeline_producer_state =
cutlass::make_producer_start_state<SchedPipeline>();
typename SchedThrottlePipeline::Params sched_throttle_pipeline_params;
if (is_dma_warp) {
sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Producer;
}
if (is_sched_warp) {
sched_throttle_pipeline_params.role = SchedThrottlePipeline::ThreadCategory::Consumer;
}
sched_throttle_pipeline_params.producer_arv_count = NumMainloopLoadThreads;
sched_throttle_pipeline_params.consumer_arv_count = NumSchedThreads;
sched_throttle_pipeline_params.dst_blockid = 0;
sched_throttle_pipeline_params.initializing_warp = 4;
SchedThrottlePipeline sched_throttle_pipeline(shared_storage.sched_throttle,
sched_throttle_pipeline_params);
SchedThrottlePipelineState sched_pipeline_throttle_consumer_state;
SchedThrottlePipelineState sched_pipeline_throttle_producer_state =
cutlass::make_producer_start_state<SchedThrottlePipeline>();
if (warp_idx == 2 && elect_one_sync()) {
cute::initialize_barrier(shared_storage.tma_barrier[0], /* num_threads */ 1);
}
__syncthreads();
// Warp group roles: DMA (global->shared copy), MMA (tensor core gemm), scheduler, column quantizer, row quantizer
if (is_dma_warp) {
    // Warp responsible for loading input from global to shared memory using TMA (Tensor Memory Accelerator).
cutlass::arch::warpgroup_reg_dealloc<32>();
// Get TMA tensors for input matrix A and B (Hadamard/transform matrix) from global memory.
Tensor mA = tma_load_a.get_tma_tensor(make_shape(M, packed_N));
Tensor mB = tma_load_b.get_tma_tensor(make_shape(RhtTensorSize, RhtTensorSize));
// Partition tensors for tiling according to the mainloop and cluster tilers.
Tensor gA_mk = local_tile(mA, mainloop_tiler, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor gB_nk =
local_tile(mB, cluster_tile, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,k)
// Shared memory tensors for pipeline
Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()),
sAlayout); // (MMA,MMA_M,MMA_N,PIPE)
Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()),
sBlayout); // (MMA,MMA_N,MMA_K,PIPE)
// Determine warp/tile positioning
int block_rank_in_cluster = cute::block_rank_in_cluster();
ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx
// Partition global to local fragments for A and B
Tensor tCgA = thr_mma.partition_A(gA_mk); // (MMA,MMA_M,MMA_K,k)
Tensor tCgB = thr_mma.partition_B(gB_nk); // (MMA,MMA_N,MMA_K,k)
Layout cta_layout_mnk = make_layout(cluster_shape);
Layout cta_layout_vmnk =
tiled_divide(cta_layout_mnk, make_tile(typename TiledMMA::AtomThrID{}));
auto cta_coord_vmnk = cta_layout_vmnk.get_flat_coord(block_rank_in_cluster);
auto [tAgA, tAsA] =
tma_partition(tma_load_a, get<2>(cta_coord_vmnk), make_layout(size<2>(cta_layout_vmnk)),
group_modes<0, 3>(tCsA), group_modes<0, 3>(tCgA));
auto [tBgB, tBsB] =
tma_partition(tma_load_b, get<1>(cta_coord_vmnk), make_layout(size<1>(cta_layout_vmnk)),
group_modes<0, 3>(tCsB), group_modes<0, 3>(tCgB));
uint16_t tma_mcast_mask_a = create_tma_multicast_mask<2>(cta_layout_vmnk, cta_coord_vmnk);
uint16_t tma_mcast_mask_b = create_tma_multicast_mask<1>(cta_layout_vmnk, cta_coord_vmnk);
if constexpr (kEnableRHTColQuant) {
if (elect_one_sync()) {
cute::set_barrier_transaction_bytes(shared_storage.tma_barrier[0],
kTmaRhtTensorTransactionBytes);
copy(tma_load_b.with(shared_storage.tma_barrier[0], tma_mcast_mask_b), tBgB(_, 0, 0),
tBsB(_, 0));
}
}
do {
      // is_first_wave indicates whether this CTA is still on its initial, statically assigned tile
      // (before any dynamically fetched work).
bool is_first_wave = scheduler.is_first_wave();
uint32_t skip_wait = is_first_wave;
auto tAgA_mk = tAgA(_, scheduler.tile_m(), _);
int k_tile = 0;
sched_throttle_pipeline.producer_acquire(sched_pipeline_throttle_producer_state);
sched_throttle_pipeline.producer_commit(sched_pipeline_throttle_producer_state);
++sched_pipeline_throttle_producer_state;
CUTLASS_PRAGMA_NO_UNROLL
while (k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n()) {
int k_tile_idx_n = scheduler.tile_n_base() + k_tile;
++k_tile;
skip_wait = (is_first_wave && k_tile < MainloopPipelineStageCount);
mainloop_pipeline.producer_acquire(mainloop_pipe_producer_state);
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType *tma_barrier =
mainloop_pipeline.producer_get_barrier(mainloop_pipe_producer_state);
int write_stage = mainloop_pipe_producer_state.index();
++mainloop_pipe_producer_state;
if (cute::elect_one_sync()) {
copy(tma_load_a.with(*tma_barrier, tma_mcast_mask_a), tAgA_mk(_, k_tile_idx_n),
tAsA(_, write_stage));
}
}
scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state);
++sched_pipeline_consumer_state;
scheduler.update_work_tile_info();
// scheduler.advance();
} while (scheduler.is_valid());
mainloop_pipeline.producer_tail(mainloop_pipe_producer_state);
} else if (is_mma_warp) {
// This warp executes the main tensor core matrix-multiply-accumulate for the Hadamard transform.
cutlass::arch::warpgroup_reg_dealloc<32>();
if constexpr (kEnableRHTColQuant) {
// Setup shared memory fragments for A and B tiles.
Tensor tCsA = make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()),
sAlayout); // (MMA,MMA_M,MMA_N,PIPE)
Tensor tCsB = make_tensor(make_smem_ptr(shared_storage.tensors.smem_B.data()),
sBlayout); // (MMA,MMA_N,MMA_K,PIPE)
int block_rank_in_cluster = cute::block_rank_in_cluster();
ThrMMA thr_mma = mma.get_slice(block_rank_in_cluster); // blk idx
// Allocate "fragments" -- these are actually umma smem descriptors
Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_M,MMA_K,PIPE)
mma.accumulate_ = UMMA::ScaleOut::Zero;
tmem_allocator.allocate(TmemAllocator::Sm100TmemCapacityColumns,
&shared_storage.tmem_base_ptr);
__syncwarp();
tmem_allocation_result_barrier.arrive();
uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
bulk_tmem_mma.data() = tmem_base_ptr;
// Wait until the B (Hadamard) tensor copy is complete
cute::wait_barrier(shared_storage.tma_barrier[0], 0 /*tma_phase_bit*/);
do {
uint32_t skip_wait = K_TILE_MAX <= 0;
auto barrier_token =
mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state);
++sched_pipeline_consumer_state;
CUTLASS_PRAGMA_NO_UNROLL
for (int k_tile = 0;
k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) {
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
int read_stage = mainloop_pipe_consumer_state.index();
auto tCrA_mk = tCrA(_, _, _, read_stage);
auto tCrB_nk = tCrB(_, _, 0, 0);
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < size<2>(tCrA) / EpilogueUnrollFactor; ++k_block) {
int accumulator_k_block =
accumulator_pipe_producer_state.index() * EpilogueUnrollFactor;
int tCrA_k_block = k_block * EpilogueUnrollFactor;
accumulator_pipeline.producer_acquire(accumulator_pipe_producer_state);
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < EpilogueUnrollFactor; i++) {
auto accumulators = bulk_tmem_mma(_, _, _, accumulator_k_block + i);
gemm(mma, tCrA_mk(_, _, tCrA_k_block + i), tCrB_nk, accumulators);
}
accumulator_pipeline.producer_commit(accumulator_pipe_producer_state);
++accumulator_pipe_producer_state;
}
auto curr_mainloop_pipe_consumer_state = mainloop_pipe_consumer_state;
++mainloop_pipe_consumer_state;
++k_tile;
skip_wait = k_tile >= K_TILE_MAX;
mainloop_pipeline.umma_consumer_release(curr_mainloop_pipe_consumer_state);
barrier_token =
mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state, skip_wait);
}
scheduler.update_work_tile_info();
} while (scheduler.is_valid());
tmem_allocator.release_allocation_lock();
accumulator_pipeline.producer_tail(accumulator_pipe_producer_state);
tmem_allocator.free(tmem_base_ptr, TmemAllocator::Sm100TmemCapacityColumns);
}
} else if (is_sched_warp) {
// Scheduler warp manages tile assignment and pipeline progress for warps
cutlass::arch::warpgroup_reg_dealloc<32>();
do {
sched_throttle_pipeline.consumer_wait(sched_pipeline_throttle_consumer_state);
sched_throttle_pipeline.consumer_release(sched_pipeline_throttle_consumer_state);
++sched_pipeline_throttle_consumer_state;
sched_pipeline_producer_state =
scheduler.advance_to_next_work(sched_pipeline, sched_pipeline_producer_state);
scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state);
++sched_pipeline_consumer_state;
scheduler.update_work_tile_info();
} while (scheduler.is_valid());
} else if (is_epilogue_col_quant_warp) {
// Warp responsible for quantizing output of Hadamard transform to FP4 for columnwise usage,
// and writing result tensors/scales to global memory.
cutlass::arch::warpgroup_reg_alloc<192>();
if constexpr (kEnableRHTColQuant) {
using TMEM_LOAD_NEW = cute::SM100::TMEM::LOAD::SM100_TMEM_LOAD_32dp32b64x;
auto acc_epilogue_pipelined_shape =
append(acc_shape_epilogue, Int<AccumulatorPipelineStageCount / EpilogueUnrollFactor>{});
auto bulk_tmem_epilogue_layout = make_layout(
acc_epilogue_pipelined_shape,
make_stride(stride<0>(bulk_tmem_mma), Int<0>{}, Int<0>{}, size<1>(epilogue_tiler)));
auto bulk_tmem_epilogue = make_tensor(make_tmem_ptr<uint32_t>(), bulk_tmem_epilogue_layout);
// Use 256-bit fragments for aligned bulk stores
static int constexpr FragmentSize = 256 / sizeof_bits_v<TD>;
// Wait for TMEM allocation for this pipeline to finish
tmem_allocation_result_barrier.arrive_and_wait();
uint32_t tmem_base_ptr = shared_storage.tmem_base_ptr;
bulk_tmem_epilogue.data() = tmem_base_ptr;
int global_thread_idx = threadIdx.x;
int local_thread_idx = global_thread_idx % cutlass::NumThreadsPerWarpGroup;
// g2s load all global_d_amax
CUTLASS_PRAGMA_NO_UNROLL
for (int g = local_thread_idx; g < num_tensors; g += NumEpilogueColQuantThreadCount) {
shared_storage.global_d_amax[g] = __ldg(reinterpret_cast<float *>(amax_colwise + g));
}
size_t rng_seed = 0;
size_t rng_offset = 0;
// Setup RNG for stochastic rounding
if constexpr (kEnableStochasticRounding) {
rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0;
rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0;
}
// TODO(zhongbo): double check the logic here
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
(scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
packed_N, M, offsets);
// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
// TODO(zhongbo): double check the logic here
Tensor mD = make_tensor(cute::subbyte_iterator<TD>(reinterpret_cast<TD *>(
reinterpret_cast<char *>(QA_COLWISE) + offsets[group_idx] / 2)),
make_shape(M, cur_N), DStride{}); // (M,packed_N)
Tensor gD_mn =
local_tile(mD, epilogue_tiler, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N)
      // For every row-major tensor [x, y] with x and y both multiples of 128,
      // its row-wise and column-wise scaling factors each contain exactly x * y / 16 elements in FP8 E4M3.
Tensor mSFD = make_tensor(
make_gmem_ptr<TSFD>(reinterpret_cast<TSFD *>(reinterpret_cast<char *>(SFA_COLWISE) +
offsets[group_idx] / kNVFP4BlockSize)),
sfd_layout);
Tensor gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _),
Step<_1, _1, X>{}); // (BLK_M,BLK_N)
Tensor gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler));
// Setup tile-level TMEM (t2r) and global memory (r2g) copy descriptors
auto tiled_t2r = make_tmem_copy(TMEM_LOAD_NEW{}, bulk_tmem_epilogue(_, _, _, _0{}));
auto tiled_r2g =
make_tiled_copy_D(Copy_Atom<SM100_STORE_256bit_CACHE_NOALLOCATION, TD>{}, tiled_t2r);
auto thr_t2r = tiled_t2r.get_slice(local_thread_idx);
auto thr_r2g = tiled_r2g.get_slice(local_thread_idx);
cutlass::arch::NamedBarrier::sync(NumEpilogueColQuantThreadCount,
cutlass::arch::ReservedNamedBarriers::EpilogueBarrier);
// Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release}
static constexpr float fp4_max = 6.0f;
static constexpr float fp8_max = 448.0f;
static constexpr float fp4_max_inv = 1.0f / fp4_max;
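      // Two-level NVFP4 scaling: the per-tensor encode scale maps the global amax to
      // fp8_max * fp4_max, and a per-16-element FP8 (E4M3) scale factor then maps each vector's
      // amax into FP4 range; elements are quantized by 1 / (decode(SF) * global_decode_scale).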
float c_global_amax_val = shared_storage.global_d_amax[group_idx];
float global_encode_scale = c_global_amax_val > 0.0f
? cutlass::minimum_with_nan_propagation<float>{}(
(fp8_max * fp4_max) / c_global_amax_val,
cutlass::platform::numeric_limits<float>::max())
: 1.0f;
float global_decode_scale = 1.0f / global_encode_scale;
// Scaling factor for fast math path
float global_encode_scale_multiplier = 1.0f;
if constexpr (kUseFastMath) {
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
do {
scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state);
++sched_pipeline_consumer_state;
CUTLASS_PRAGMA_NO_UNROLL
for (int k_tile = 0;
k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();
++k_tile) {
int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler);
// TODO(zhongbo): double check the logic here
int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors,
global_tile_n_offset * M, packed_N, M, offsets);
if (cur_group_idx != group_idx) {
group_idx = cur_group_idx;
c_global_amax_val = shared_storage.global_d_amax[group_idx];
// update amax
global_encode_scale = c_global_amax_val > 0.0f
? cutlass::minimum_with_nan_propagation<float>{}(
(fp8_max * fp4_max) / c_global_amax_val,
cutlass::platform::numeric_limits<float>::max())
: 1.0f;
global_decode_scale = 1.0f / global_encode_scale;
if constexpr (kUseFastMath) {
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
// TODO(zhongbo): double check the logic here
cur_N = first_dims[group_idx];
if constexpr (kEnableSwizzleSFOutput) {
sfd_layout =
tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
sfd_layout =
make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// update tensor
mD = make_tensor(cute::subbyte_iterator<TD>(reinterpret_cast<TD *>(
reinterpret_cast<char *>(QA_COLWISE) + offsets[group_idx] / 2)),
make_shape(M, cur_N), DStride{});
gD_mn = local_tile(mD, epilogue_tiler, make_coord(_, _, _),
Step<_1, _1, X>{}); // (BLK_M,BLK_N)
mSFD = make_tensor(
make_gmem_ptr<TSFD>(reinterpret_cast<TSFD *>(reinterpret_cast<char *>(SFA_COLWISE) +
offsets[group_idx] / kNVFP4BlockSize)),
sfd_layout);
gSFD_mn = local_tile(mSFD, epilogue_tiler, make_coord(_, _, _),
Step<_1, _1, X>{}); // (BLK_M,BLK_N)
gD_mn_view = tiled_divide(gD_mn, take<0, 2>(epilogue_tiler));
}
int group_start_offset = offsets[group_idx] / M;
int local_tile_n_idx =
(global_tile_n_offset - group_start_offset) / size<1>(epilogue_tiler);
Tensor tDgD_mn = gD_mn_view(_, _, _, scheduler.tile_m(), local_tile_n_idx);
Tensor tDgSFD_mn = gSFD_mn(_, _, scheduler.tile_m(), local_tile_n_idx);
accumulator_pipeline.consumer_wait(accumulator_pipe_consumer_state);
auto Acc = bulk_tmem_epilogue(_, _, _, accumulator_pipe_consumer_state.index());
Tensor tDtAcc = thr_t2r.partition_S(Acc); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tDgD = thr_t2r.partition_D(tDgD_mn); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tTR_rAcc =
make_tensor<ElementAccumulator>(shape(tDgD)); // ((TMEM_LOAD,#TMEM_LOAD),MMA_M,MMA_N)
Tensor tDrD = make_tensor<TD>(shape(tDgD));
Tensor tTR_rAcc_frag =
recast<cutlass::Array<ElementAccumulator, FragmentSize>>(coalesce(tTR_rAcc));
Tensor tDrD_frag = recast<cutlass::Array<TD, FragmentSize>>(coalesce(tDrD));
Tensor src = thr_r2g.retile_S(tDrD);
Tensor dst = thr_r2g.retile_D(tDgD);
Tensor tDgSFD_view = make_tensor(
tDgSFD_mn.data(), make_layout(make_shape(shape(tDgSFD_mn), Int<1>{}, Int<1>{}),
make_stride(stride(tDgSFD_mn), Int<0>{}, Int<0>{})));
Tensor tDgSFD = filter(thr_t2r.partition_D(tDgSFD_view));
Tensor tDrSFD = make_tensor<TSFD>(shape(tDgSFD));
static int constexpr NumVecs = size(tDgD) / VectorSize;
Tensor tD_rRowSFD_frg = recast<cutlass::Array<TSFD, NumVecs>>(tDrSFD);
// Compute amax and quantization scales for this tile
cutlass::maximum_absolute_value_reduction<cutlass::Array<ElementAccumulator, VectorSize>,
true>
amax_reduction;
cutlass::Array<ElementAccumulator, NumVecs> vec_maxs;
cutlass::Array<ElementAccumulator, NumVecs> pvscales;
// Copy from TMEM to registers
copy(tiled_t2r, tDtAcc, tTR_rAcc);
cutlass::arch::fence_view_async_tmem_load();
accumulator_pipeline.consumer_release(accumulator_pipe_consumer_state);
++accumulator_pipe_consumer_state;
if constexpr (!kUseFastMath) {
// Downcast to BF16 for bit-wise compatibility with
// unfused kernels
auto convert_accum_to_bf16 =
cutlass::NumericArrayConverter<cutlass::bfloat16_t, ElementAccumulator,
FragmentSize>{};
auto convert_bf16_to_accum =
cutlass::NumericArrayConverter<ElementAccumulator, cutlass::bfloat16_t,
FragmentSize>{};
tTR_rAcc_frag(_0{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_0{})));
tTR_rAcc_frag(_1{}) = convert_bf16_to_accum(convert_accum_to_bf16(tTR_rAcc_frag(_1{})));
}
auto compute_frgs = reinterpret_cast<cutlass::Array<ElementAccumulator, VectorSize> *>(
tTR_rAcc_frag.data());
auto output_frgs = reinterpret_cast<cutlass::Array<TD, VectorSize> *>(tDrD_frag.data());
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < NumVecs; v++) {
vec_maxs[v] = amax_reduction(ElementAccumulator(0), compute_frgs[v]);
}
if constexpr (kUseFastMath) {
// Fast math: multiply with precomputed reciprocal
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(
vec_maxs, global_encode_scale_multiplier);
} else {
// Accurate math: perform division
pvscales =
cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(vec_maxs, fp4_max);
pvscales = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(
pvscales, global_encode_scale);
}
auto pvscales_cvted =
cutlass::NumericArrayConverter<TSFD, ElementAccumulator, NumVecs>{}(pvscales);
tD_rRowSFD_frg(_0{}) = pvscales_cvted;
auto qpvscale_ups = cutlass::NumericArrayConverter<ElementAccumulator, TSFD, NumVecs>{}(
tD_rRowSFD_frg(_0{}));
auto qpvscale_scaled = cutlass::multiplies<cutlass::Array<ElementAccumulator, NumVecs>>{}(
qpvscale_ups, global_decode_scale);
cutlass::Array<ElementAccumulator, NumVecs> acc_scales;
if constexpr (kUseFastMath) {
// Fast math: compute approximate reciprocal
acc_scales =
cutlass::reciprocal_approximate_ftz<decltype(qpvscale_scaled)>{}(qpvscale_scaled);
} else {
// Accurate math: compute reciprocal with division
acc_scales = cutlass::divides<cutlass::Array<ElementAccumulator, NumVecs>>{}(
1.0, qpvscale_scaled);
}
// Prepare stochastic rounding random state if enabled
uint4 random_uint4 = uint4{0, 0, 0, 0};
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
// "Prefetch" a stochastic rounding state for the first tile
if constexpr (kEnableStochasticRounding) {
const size_t rng_sequence = global_thread_idx + k_tile * 512 +
scheduler.get_linear_tile_idx() * K_TILE_MAX * 512;
rng.init(rng_seed, rng_sequence, rng_offset);
}
CUTLASS_PRAGMA_UNROLL
// Apply round/quantize to each fragment, with or without stochastic rounding
for (int v = 0; v < NumVecs; v++) {
auto acc_scale = cutlass::minimum_with_nan_propagation<ElementAccumulator>{}(
acc_scales[v], cutlass::platform::numeric_limits<ElementAccumulator>::max());
if constexpr (kEnableStochasticRounding) {
random_uint4 = rng.generate4();
output_frgs[v] = StochasticNumericConverter(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs[v], acc_scale),
*reinterpret_cast<cutlass::Array<uint32_t, 4> *>(&random_uint4));
} else {
output_frgs[v] = cutlass::NumericArrayConverter<TD, ElementAccumulator, VectorSize>{}(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs[v], acc_scale));
}
}
// Write quantized FP4 tile and dequant scale to gmem
copy(tiled_r2g, src, dst);
copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, tDrSFD, tDgSFD);
}
scheduler.update_work_tile_info();
} while (scheduler.is_valid());
}
} else if (is_epilogue_row_quant_warp) {
// Warp responsible for quantizing the input (before Hadamard transform) to FP4 for row-wise usage.
cutlass::arch::warpgroup_reg_alloc<136>();
if constexpr (kEnableRowQuant) {
using S2RVectorType = uint128_t;
int global_thread_idx = threadIdx.x;
int local_thread_idx = global_thread_idx % 256;
size_t rng_seed = 0;
size_t rng_offset = 0;
// g2s load all global_a_amax for all groups/tensors
CUTLASS_PRAGMA_NO_UNROLL
for (int g = local_thread_idx; g < num_tensors; g += NumEpilogueRowQuantThreadCount) {
shared_storage.global_a_amax[g] = __ldg(reinterpret_cast<float *>(amax_rowwise + g));
}
// RNG for stochastic rounding
if constexpr (kEnableStochasticRounding) {
rng_seed = rng_state != nullptr ? __ldg(rng_state) : 0;
rng_offset = rng_state != nullptr ? __ldg(rng_state + 1) : 0;
}
// Input/output tensors/partitions for row quant warp
Tensor mQA =
make_tensor(cute::subbyte_iterator<TQA>(QA), make_layout(make_shape(M, packed_N), dQA));
Tensor gQA_mn = local_tile(mQA, epilogue_tiler, make_coord(_, _, _), Step<_1, X, _1>{});
Tensor mSFA = make_tensor(make_gmem_ptr(SFA), sfa_layout);
Tensor gSFA_mn = local_tile(mSFA, epilogue_tiler, make_coord(_, _, _),
Step<_1, X, _1>{}); // (BLK_M,BLK_N)
// Swizzled shared memory A tile, with layout
Tensor sA = as_position_independent_swizzle_tensor(group_modes<0, 2>(
coalesce(make_tensor(make_smem_ptr(shared_storage.tensors.smem_A.data()),
sAlayout)))); // (BLOCK_M, BLOCK_M,PIPE)
// Set up layouts for partitioning – tile-by-warp, with vector granularity
using S2RWarpLayout = Layout<Shape<_8, _4>>;
using WarpGroupLayout = Layout<Shape<_1, _8>>;
using S2RThreadLayout = decltype(blocked_product(S2RWarpLayout{}, WarpGroupLayout{}));
using S2RValLayout = Layout<Shape<Int<VectorSize>, _1>>;
using S2RAtomA = Copy_Atom<AutoVectorizingCopy, TA>;
using R2GAtomQA = Copy_Atom<AutoVectorizingCopy, TQA>;
using R2GAtomSFA = Copy_Atom<AutoVectorizingCopy, TSFA>;
auto tiled_s2r = make_tiled_copy(S2RAtomA{}, S2RThreadLayout{}, S2RValLayout{});
auto tiled_r2g_QA = make_tiled_copy(R2GAtomQA{}, S2RThreadLayout{}, S2RValLayout{});
auto tiled_r2g_SFA = make_tiled_copy(R2GAtomSFA{}, S2RThreadLayout{}, S2RValLayout{});
auto thr_s2r = tiled_s2r.get_slice(local_thread_idx);
auto thr_r2g_QA = tiled_r2g_QA.get_slice(local_thread_idx);
auto thr_r2g_SFA = tiled_r2g_SFA.get_slice(local_thread_idx);
Tensor tQAsA = thr_s2r.partition_S(sA); // (Copy, Copy_M, Copy_N, PIPE)
// Allocate temporary register tensors for copying quantization => output
Tensor tQArA = make_tensor_like<TA>(
make_layout(tQAsA(_, _, _, _0{}).shape())); // (Copy, Copy_M, Copy_N)
Tensor tQAgQA = thr_r2g_QA.partition_S(gQA_mn);
Tensor tQArQA = make_tensor_like(tQAgQA(_, _, _, _0{}, _0{}));
Tensor tQAgSFA = thr_r2g_SFA.partition_S(gSFA_mn);
Tensor tQArSFA = make_tensor_like(tQAgSFA(_, _, _, _0{}, _0{}));
      // Results in barrier_id=10 being passed to the bar.sync instruction, since CUTLASS adds 8
      // to skip past the reserved named barriers.
constexpr int row_quant_barrier_id = 2;
cutlass::arch::NamedBarrier::sync(NumEpilogueRowQuantThreadCount, row_quant_barrier_id);
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
(scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
packed_N, M, offsets);
float a_global_amax_val = shared_storage.global_a_amax[group_idx];
// Aligning with TensorEngine's recipe to generate scale factors // {$nv-internal-release}
static constexpr float fp4_max = 6.0f;
static constexpr float fp8_max = 448.0f;
static constexpr float fp4_max_inv = 1.0f / fp4_max;
float global_encode_scale = a_global_amax_val > 0.0f
? cutlass::minimum_with_nan_propagation<float>{}(
(fp8_max * fp4_max) / a_global_amax_val,
cutlass::platform::numeric_limits<float>::max())
: 1.0f;
float global_decode_scale = 1.0f / global_encode_scale;
float global_encode_scale_multiplier = 1.0f;
if constexpr (kUseFastMath) {
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
auto sfa_converter = cutlass::NumericConverter<TSFA, ElementAccumulator>{};
do {
CUTLASS_PRAGMA_NO_UNROLL
for (int k_tile = 0;
k_tile < K_TILE_MAX && k_tile + scheduler.tile_n_base() < scheduler.tiles_n();) {
int global_tile_n_offset = (scheduler.tile_n_base() + k_tile) * size<1>(epilogue_tiler);
int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors,
global_tile_n_offset * M, packed_N, M, offsets);
if (cur_group_idx != group_idx) {
group_idx = cur_group_idx;
a_global_amax_val = shared_storage.global_a_amax[group_idx];
// Update group quantization parameters/scaling
global_encode_scale = a_global_amax_val > 0.0f
? cutlass::minimum_with_nan_propagation<float>{}(
(fp8_max * fp4_max) / a_global_amax_val,
cutlass::platform::numeric_limits<float>::max())
: 1.0f;
global_decode_scale = 1.0f / global_encode_scale;
if constexpr (kUseFastMath) {
global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
}
}
auto tQAgSFA_mn = tQAgSFA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile);
auto tQAgQA_mn = tQAgQA(_, _, _, scheduler.tile_m(), scheduler.tile_n_base() + k_tile);
auto barrier_token = mainloop_pipeline.consumer_try_wait(mainloop_pipe_consumer_state);
mainloop_pipeline.consumer_wait(mainloop_pipe_consumer_state, barrier_token);
copy(tiled_s2r, tQAsA(_, _, _, mainloop_pipe_consumer_state.index()), tQArA);
cutlass::arch::fence_view_async_shared();
mainloop_pipeline.consumer_release(mainloop_pipe_consumer_state);
++mainloop_pipe_consumer_state;
++k_tile;
// static int constexpr NumVecs = size(tQArA) / VectorSize;
cutlass::maximum_absolute_value_reduction<cutlass::Array<ElementAccumulator, VectorSize>,
true>
amax_reduction;
auto compute_frgs = reinterpret_cast<cutlass::Array<TA, VectorSize> *>(tQArA.data());
auto output_frgs =
reinterpret_cast<cutlass::Array<TQA, VectorSize> *>(raw_pointer_cast(tQArQA.data()));
Tensor amax =
make_tensor<ElementAccumulator>(prepend(take<1, rank(tQArA)>(tQArA.shape()), _1{}));
Tensor pvscales = make_tensor_like<ElementAccumulator>(amax);
transformer_engine::curanddx::detail::philox4x32_native_state<10> rng;
if constexpr (kEnableStochasticRounding) {
const size_t rng_sequence = global_thread_idx + k_tile * 512 +
scheduler.get_linear_tile_idx() * K_TILE_MAX * 512 +
tiles_in_m * tiles_in_n * K_TILE_MAX * 512;
rng.init(rng_seed, rng_sequence, rng_offset);
}
CUTLASS_PRAGMA_UNROLL
for (int v = 0; v < size<1>(group_modes<1, rank(tQArA)>(tQArA)); v++) {
auto amax_view = group_modes<1, rank(amax)>(amax);
auto pvscales_view = group_modes<1, rank(pvscales)>(pvscales);
auto compute_frgs_up =
cutlass::NumericArrayConverter<ElementAccumulator, TA, VectorSize>{}(
compute_frgs[v]);
amax_view(_0{}, v) = amax_reduction(ElementAccumulator(0), compute_frgs_up);
if constexpr (kUseFastMath) {
// Fast math: multiply with precomputed reciprocal
pvscales_view(_0{}, v) = cutlass::multiplies<ElementAccumulator>{}(
amax_view(_0{}, v), global_encode_scale_multiplier);
} else {
// Accurate math: perform division
pvscales_view(_0{}, v) =
cutlass::divides<ElementAccumulator>{}(amax_view(_0{}, v), fp4_max);
pvscales_view(_0{}, v) = cutlass::multiplies<ElementAccumulator>{}(
pvscales_view(_0{}, v), global_encode_scale);
}
filter(tQArSFA)(v) = sfa_converter(pvscales_view(_0{}, v));
auto qpvscale_ups =
cutlass::NumericConverter<ElementAccumulator, TSFA>{}(filter(tQArSFA)(v));
auto qpvscale_scaled =
cutlass::multiplies<ElementAccumulator>{}(qpvscale_ups, global_decode_scale);
ElementAccumulator acc_scales;
if constexpr (kUseFastMath) {
// Fast math: compute approximate reciprocal
acc_scales =
cutlass::reciprocal_approximate_ftz<decltype(qpvscale_scaled)>{}(qpvscale_scaled);
} else {
// Accurate math: compute reciprocal with division
acc_scales = cutlass::divides<ElementAccumulator>{}(1.0, qpvscale_scaled);
}
auto acc_scale = cutlass::minimum_with_nan_propagation<ElementAccumulator>{}(
acc_scales, cutlass::platform::numeric_limits<ElementAccumulator>::max());
uint4 random_uint4 = uint4{0, 0, 0, 0};
if constexpr (kEnableStochasticRounding) {
random_uint4 = rng.generate4();
output_frgs[v] = StochasticNumericConverter(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs_up, acc_scale),
*reinterpret_cast<cutlass::Array<uint32_t, 4> *>(&random_uint4));
} else {
output_frgs[v] =
cutlass::NumericArrayConverter<TQA, ElementAccumulator, VectorSize>{}(
cutlass::multiplies<cutlass::Array<ElementAccumulator, VectorSize>>{}(
compute_frgs_up, acc_scale));
}
}
copy(tiled_r2g_QA, tQArQA, tQAgQA_mn);
copy(tiled_r2g_SFA, filter(tQArSFA), filter(tQAgSFA_mn));
}
// scheduler.advance();
scheduler.fetch_next_work(sched_pipeline, sched_pipeline_consumer_state);
++sched_pipeline_consumer_state;
scheduler.update_work_tile_info();
} while (scheduler.is_valid());
}
} else {
cutlass::arch::warpgroup_reg_dealloc<32>();
}
} // NOLINT(readability/fn_size)
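// Host-side setup for the fused row/column RHT quantization GEMM: builds the strides,
// scale-factor layouts, and cluster/tile shapes consumed by
// group_row_col_rht_gemm_device_graph_safe above.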
template <bool kEnableStochasticRounding, bool kEnableRHTColQuant, bool kEnableRowQuant,
bool kEnableSwizzleSFOutput, class TA, class TB, class TQA, class TSFA, class TD = TQA,
class TSFD = TSFA, bool kUseFastMath = false>
void group_row_col_rht_gemm_ntt_w_sfc_graph_safe(
int packed_sequence_length, int hidden_size, size_t num_tensors, ShapeRepresentation shape_rep,
TA const *A, TB const *B, TQA *QA, TSFA *SFA, TQA *QA_COLWISE, TSFA *SFA_COLWISE,
float *amax_rowwise, float *amax_colwise, const int64_t *offsets, const int64_t *first_dims,
const size_t *rng_state, uint32_t *tile_scheduler_workspace, uint32_t sm_count,
cudaStream_t stream, int k_tile_size = 1024) {
using namespace cute;
static int constexpr SFVecSize = 16;
static int constexpr RhtTensorSize = 16;
static_assert(RhtTensorSize == 16, "RhtTensorSize must be 16");
using LinearSFALayout = decltype(make_layout(make_shape(make_shape(Int<SFVecSize>{}, 0), 0),
make_stride(make_stride(_0{}, _1{}), 0)));
using LinearSFDLayout = decltype(make_layout(make_shape(0, make_shape(Int<SFVecSize>{}, 0)),
make_stride(0, make_stride(_0{}, _1{}))));
using SwizzledSFALayoutAtom =
cutlass::detail::Sm1xxBlockScaledOutputConfig<SFVecSize, UMMA::Major::MN>::SfAtom;
using SwizzledSFDLayoutAtom =
cutlass::detail::Sm1xxBlockScaledOutputConfig<SFVecSize, UMMA::Major::K>::SfAtom;
using SwizzledSFALayout = decltype(tile_to_shape(
SwizzledSFALayoutAtom{}, make_shape(hidden_size, packed_sequence_length), Step<_1, _2>{}));
using SwizzledSFDLayout = decltype(tile_to_shape(
SwizzledSFDLayoutAtom{}, make_shape(hidden_size, packed_sequence_length), Step<_2, _1>{}));
using SFALayout = cute::conditional_t<kEnableSwizzleSFOutput, SwizzledSFALayout, LinearSFALayout>;
using SFDLayout = cute::conditional_t<kEnableSwizzleSFOutput, SwizzledSFDLayout, LinearSFDLayout>;
SFALayout sfa_layout;
SFDLayout sfd_layout;
if constexpr (kEnableSwizzleSFOutput) {
sfa_layout = tile_to_shape(SwizzledSFALayoutAtom{},
make_shape(hidden_size, packed_sequence_length), Step<_1, _2>{});
sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{},
make_shape(hidden_size, packed_sequence_length), Step<_2, _1>{});
} else {
sfa_layout = make_layout(
make_shape(make_shape(Int<SFVecSize>{}, hidden_size / SFVecSize), packed_sequence_length),
make_stride(make_stride(_0{}, _1{}), hidden_size / SFVecSize));
sfd_layout = make_layout(
make_shape(hidden_size, make_shape(Int<SFVecSize>{}, packed_sequence_length / SFVecSize)),
make_stride(packed_sequence_length / SFVecSize, make_stride(_0{}, _1{})));
}
// Define shapes (dynamic)
auto M = hidden_size;
auto N = packed_sequence_length;
Tensor tensorA = make_tensor(A, make_shape(hidden_size, packed_sequence_length), LayoutLeft{});
Tensor tensorB = make_tensor(B, make_shape(RhtTensorSize, RhtTensorSize), LayoutLeft{});
Tensor tensorQA = make_tensor(QA, make_shape(hidden_size, packed_sequence_length), LayoutLeft{});
Tensor tensorSFA = make_tensor(SFA, sfa_layout);
// Define strides (from tensors)
auto dA = stride(tensorA); // (dM,dK)
auto dB = stride(tensorB); // (dN,dK)
auto dD = LayoutRight{}; // (dM,dN)
auto dQA = stride(tensorQA); // (dM,dK)
using ClusterShape = Shape<_1, _1, _1>;
auto cluster_shape = ClusterShape{};
auto cluster_tile_shape = Shape<_128, Int<RhtTensorSize>, Int<RhtTensorSize>>{};
auto cluster_tile_mainloop = Shape<_128, Int<RhtTensorSize>, _128>{};
  // Each mainloop / epilogue stage loads a 128 x 128 tile while each MMA step proceeds with 128 x 16 tiles
static int constexpr EpilogueUnrollFactor =
size<2>(cluster_tile_mainloop) / size<2>(cluster_tile_shape);
// Construct the MMA
auto mma = make_tiled_mma(
SM100_MMA_F16BF16_SS<TA, TB, float, size<0>(cluster_tile_shape), size<1>(cluster_tile_shape),
UMMA::Major::MN, UMMA::Major::MN>{},
Layout<Shape<_1, _1>>{});
// Assert that the TiledMMA uses all CTAs in the CGA.
CUTE_STATIC_ASSERT_V(size(cluster_shape) == size(mma));
CUTE_STATIC_ASSERT_V(evenly_divides(cluster_tile_shape, tile_shape(mma)));
// Determine the A and B shapes
auto mma_shape_B =
partition_shape_B(mma, make_shape(size<1>(cluster_tile_shape), size<2>(cluster_tile_shape)));
using TiledMma = decltype(mma);
using AtomThrID = typename TiledMma::AtomThrID;
using SmemShape_M = decltype(shape_div(
shape<0>(cluster_tile_shape),
shape_div(shape<0>(cluster_tile_shape), size<0>(cluster_tile_shape) / size(AtomThrID{}))));
using SmemShape_N = decltype(shape_div(
shape<1>(cluster_tile_shape),
shape_div(shape<1>(cluster_tile_shape), size<1>(cluster_tile_shape) / size(AtomThrID{}))));
using SmemShape_K = decltype(cute::get<2>(cluster_tile_shape));
using SmemLayoutAtomB =
decltype(cutlass::gemm::collective::detail::sm100_smem_selector<cute::UMMA::Major::MN, TB,
SmemShape_N, SmemShape_K>());
auto mma_shape_A = partition_shape_A(
mma, make_shape(size<0>(cluster_tile_mainloop), size<2>(cluster_tile_mainloop)));
using SmemShape_M_A =
decltype(shape_div(shape<0>(cluster_tile_mainloop),
shape_div(shape<0>(cluster_tile_mainloop),
size<0>(cluster_tile_mainloop) / size(AtomThrID{}))));
using SmemShape_K_A = decltype(cute::get<2>(cluster_tile_mainloop));
using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::sm100_smem_selector<
cute::UMMA::Major::MN, TA, SmemShape_M_A, SmemShape_K_A>());
static uint32_t constexpr TotalTmemRows = 128;
static uint32_t constexpr Sm100TmemCapacityColumns = 512;
static uint32_t constexpr TotalTmem = TotalTmemRows * Sm100TmemCapacityColumns;
static uint32_t constexpr AccumulatorPipelineStageCount =
TotalTmem / (cute::size<0>(cluster_tile_shape) * cute::size<1>(cluster_tile_shape));
// Define the smem layouts (static)
// Calculate max pipeline stages based on Blackwell SM100's 232KB shared memory
constexpr int SchedulerPipelineStageCount = 4;
static int constexpr MainloopPipelineBytes = sizeof(
typename cutlass::detail::CustomizedPipelineTmaUmmaAsync<1, Shape<_1, _1, _1>,
Shape<_1, _1, _1>>::SharedStorage);
static int constexpr SchedulerWorkspaceBytes = sizeof(int) * SchedulerPipelineStageCount;
static int constexpr SchedulerThrottlePipelineBytes =
sizeof(typename cutlass::PipelineAsync<SchedulerPipelineStageCount>::SharedStorage);
static int constexpr SchedulerPipelineBytes =
sizeof(typename cutlass::PipelineCLCFetchAsync<SchedulerPipelineStageCount,
ClusterShape>::SharedStorage);
static int constexpr TmemDeallocBytes = sizeof(cutlass::arch::ClusterBarrier);
static int constexpr BTensorBytes = cute::size(mma_shape_B) * sizeof(TB);
static int constexpr AccPipelineBytes = sizeof(
typename cutlass::PipelineUmmaAsync<AccumulatorPipelineStageCount / EpilogueUnrollFactor,
Shape<_1, _1, _1>>::SharedStorage);
static int constexpr TmemBasePtrsBytes = sizeof(uint32_t);
static int constexpr kBlackwellSmemSize = 232448; // 232KB in bytes
static int constexpr kBytesPerStage =
cute::size(mma_shape_A) * sizeof(TA) + MainloopPipelineBytes;
static int constexpr kReservedBytes = SchedulerWorkspaceBytes + SchedulerThrottlePipelineBytes +
SchedulerPipelineBytes + TmemBasePtrsBytes +
TmemDeallocBytes + BTensorBytes +
AccPipelineBytes; // Reserve for barriers and other uses
static int constexpr kMaxStages = (kBlackwellSmemSize - kReservedBytes) / kBytesPerStage;
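  // Back-of-the-envelope check (assuming TA = bf16): each stage holds a 128 x 128 A tile
  // (~32 KB) plus a few bytes of pipeline barriers, and the reserved region is well under 1 KB,
  // so the 232448-byte budget yields roughly 7 mainloop stages.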
auto sP = Int<kMaxStages>{}; // SMEM pipelines
auto sA = UMMA::tile_to_mma_shape(SmemLayoutAtomA{}, append(mma_shape_A, sP),
Step<_2, _1, _3>{}); // (MMA,MMA_M,MMA_K,PIPE)
auto sB = UMMA::tile_to_mma_shape(SmemLayoutAtomB{},
append(mma_shape_B, _1{})); // (MMA,MMA_N,MMA_K, _1)
auto sD = Layout<_1>{}; // XXX Dummy
auto tma_load_a =
make_tma_copy_A_sm100(SM90_TMA_LOAD{}, tensorA, sA(_, _, _, 0), cluster_tile_mainloop, mma);
auto tma_load_b =
make_tma_copy_B_sm100(SM90_TMA_LOAD{}, tensorB, sB(_, _, _, 0), cluster_tile_shape, mma);
// Assert checks on tile sizes -- no predication
assert(M % size<0>(cluster_tile_shape) == 0);
assert(N % size<1>(cluster_tile_shape) == 0);
dim3 dimBlock(512);
dim3 dimCluster(size<0>(cluster_shape), size<1>(cluster_shape), size<2>(cluster_shape));
dim3 dimGrid(sm_count, 1, 1);
int smem_size = sizeof(
SharedStorage<TA, TB, decltype(sA), decltype(sB), ClusterShape, AccumulatorPipelineStageCount,
EpilogueUnrollFactor, SchedulerPipelineStageCount>);
auto *kernel_ptr = &group_row_col_rht_gemm_device_graph_safe<
decltype(M), decltype(N), decltype(k_tile_size), decltype(cluster_shape),
decltype(cluster_tile_shape), TA, decltype(dA), decltype(sA), decltype(tma_load_a), TB,
decltype(dB), decltype(sB), decltype(tma_load_b), TD, decltype(dD), decltype(sD), TSFD,
decltype(sfd_layout), TQA, decltype(dQA), TSFA, decltype(sfa_layout), decltype(mma),
AccumulatorPipelineStageCount, SchedulerPipelineStageCount, kEnableStochasticRounding,
kEnableRHTColQuant, kEnableRowQuant, kEnableSwizzleSFOutput, kUseFastMath>;
NVTE_CHECK_CUDA(
cudaFuncSetAttribute(*kernel_ptr, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// Set workspace and set to zero
NVTE_CHECK_CUDA(cudaMemsetAsync(reinterpret_cast<void *>(tile_scheduler_workspace), 0,
sizeof(uint32_t), stream));
// Launch kernel
cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size, stream};
cutlass::Status status = cutlass::launch_kernel_on_cluster(
params, (void const *)kernel_ptr, M, N, k_tile_size, cluster_shape, cluster_tile_shape, A, dA,
sA, tma_load_a, B, dB, sB, tma_load_b, QA, dQA, SFA, sfa_layout, QA_COLWISE, SFA_COLWISE,
amax_rowwise, amax_colwise, offsets, first_dims, num_tensors, shape_rep,
tile_scheduler_workspace, mma, rng_state);
NVTE_CHECK_CUDA(cudaGetLastError());
NVTE_CHECK(status == cutlass::Status::kSuccess, "Kernel launch failed.");
}
} // namespace
} // namespace detail
void group_hadamard_transform_cast_fusion_graph_safe(const GroupedTensor *input,
GroupedTensor *output,
const Tensor &hadamard_matrix_,
QuantizationConfig &quant_config,
Tensor &quant_workspace, cudaStream_t stream) {
NVTE_API_CALL(group_hadamard_transform_cast_fusion_graph_safe);
using transformer_engine::detail::kMaxTensorsPerKernel;
using transformer_engine::detail::ShapeRepresentation;
void *input_base_ptr = reinterpret_cast<void *>(input->data.dptr);
// TODO(zhongbo): add input sanity checks here
bool all_has_row_quant = output->has_data();
bool all_has_col_quant = output->has_columnwise_data();
// Stochastic rounding config
const bool use_stochastic_rounding = quant_config.stochastic_rounding;
const size_t *rng_state = nullptr;
if (use_stochastic_rounding) {
NVTE_CHECK(quant_config.rng_state != nullptr,
"Enabled stochastic rounding without providing RNG state");
const Tensor &rng_state_tensor = *convertNVTETensorCheck(quant_config.rng_state);
NVTE_CHECK(rng_state_tensor.dtype() == DType::kInt64,
"RNG state should contain 2 64-bit values.");
NVTE_CHECK(rng_state_tensor.data.shape == std::vector<size_t>{2},
"Shape of the RNG state should be [2], but got ", rng_state_tensor.data.shape);
rng_state = reinterpret_cast<const size_t *>(rng_state_tensor.data.dptr);
}
uint32_t *tile_scheduler_workspace = nullptr;
NVTE_CHECK(quant_workspace.data.dptr != nullptr, "Quantization workspace must be provided.");
NVTE_CHECK(quant_workspace.data.buffer_size_bytes() >= sizeof(uint32_t),
"Quantization workspace must be at least 4 bytes.");
tile_scheduler_workspace = reinterpret_cast<uint32_t *>(quant_workspace.data.dptr);
// Template arguments
using TA = cute::bfloat16_t;
using TB = cute::bfloat16_t;
using TD = cutlass::float_e2m1_t;
using TSFD = cutlass::float_ue4m3_t;
using TQA = TD;
using TSFA = TSFD;
checkCuDriverContext(stream);
// Check Hadamard matrix
constexpr int kHadamardDimension = 16;
  NVTE_CHECK(hadamard_matrix_.dtype() == transformer_engine::DType::kBFloat16,
             "Hadamard matrix must be a BF16 tensor, but dtype is ",
             to_string(hadamard_matrix_.dtype()), ".");
const SimpleTensor &hadamard_matrix = hadamard_matrix_.data;
NVTE_CHECK(
(hadamard_matrix_.shape() == std::vector<size_t>{kHadamardDimension, kHadamardDimension}),
"Hadamard matrix must have shape=",
std::vector<size_t>{kHadamardDimension, kHadamardDimension},
", but got shape=", hadamard_matrix_.shape(), ".");
const size_t hadamard_dimension = hadamard_matrix.shape[0];
const size_t num_tensors = input->num_tensors;
const size_t first_logical_dim = input->logical_shape.data[0];
const size_t last_logical_dim = input->logical_shape.data[1];
// const size_t elts_total = first_logical_dim * last_logical_dim;
NVTE_CHECK(first_logical_dim % 128 == 0,
"First dimension of a grouped tensor should be divisible by 128.");
NVTE_CHECK(last_logical_dim % 128 == 0,
"Last dimension of a grouped tensor should be divisible by 128.");
  NVTE_CHECK(num_tensors <= kMaxTensorsPerKernel, "Number of tensors (", num_tensors,
             ") should be less than or equal to ", kMaxTensorsPerKernel, ".");
ShapeRepresentation shape_rep = ShapeRepresentation::VARYING_FIRST_DIM;
if (output->all_same_shape()) {
shape_rep = ShapeRepresentation::SAME_BOTH_DIMS;
} else if (output->all_same_first_dim()) {
shape_rep = ShapeRepresentation::VARYING_LAST_DIM;
} else if (output->all_same_last_dim()) {
shape_rep = ShapeRepresentation::VARYING_FIRST_DIM;
} else if (output->varying_both_dims()) {
shape_rep = ShapeRepresentation::VARYING_BOTH_DIMS;
}
TQA *const rowwise_data_base_ptr = reinterpret_cast<TQA *>(output->data.dptr);
TSFA *const rowwise_scale_inv_base_ptr = reinterpret_cast<TSFA *>(output->scale_inv.dptr);
TQA *const colwise_data_base_ptr = reinterpret_cast<TQA *>(output->columnwise_data.dptr);
TSFA *const colwise_scale_inv_base_ptr =
reinterpret_cast<TSFA *>(output->columnwise_scale_inv.dptr);
float *const amax_rowwise_base_ptr = reinterpret_cast<float *>(output->amax.dptr);
float *const amax_colwise_base_ptr = reinterpret_cast<float *>(output->columnwise_amax.dptr);
const int64_t *const offsets_ptr = reinterpret_cast<const int64_t *>(input->tensor_offsets.dptr);
const int64_t *const first_dims_ptr = reinterpret_cast<const int64_t *>(input->first_dims.dptr);
// const int64_t *const last_dims_ptr = reinterpret_cast<const int64_t *>(input->last_dims.dptr);
const bool is_const_last_dim = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS ||
shape_rep == ShapeRepresentation::VARYING_FIRST_DIM);
  NVTE_CHECK(is_const_last_dim,
             "Currently only a constant last dimension is supported for the graph-safe Hadamard transform.");
auto sm_count = transformer_engine::cuda::sm_count();
int k_tile_size = 1024;
const bool use_swizzle_sf_output = false;
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_stochastic_rounding, kEnableStochasticRounding,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
all_has_col_quant, kEnableRhtColQuant,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
all_has_row_quant, kEnableRowQuant,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
use_swizzle_sf_output, kEnableSwizzleSFOutput,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
quant_config.use_fast_math, kUseFastMath,
if constexpr (kEnableRhtColQuant || kEnableRowQuant) {
detail::group_row_col_rht_gemm_ntt_w_sfc_graph_safe<
kEnableStochasticRounding, kEnableRhtColQuant, kEnableRowQuant,
kEnableSwizzleSFOutput, TA, TB, TQA, TSFA, TD, TSFD, kUseFastMath>(
/*packed_sequence_length=*/first_logical_dim,
/*hidden_size=*/last_logical_dim,
/*num_tensors=*/num_tensors,
/*shape_rep=*/shape_rep,
/*A=*/reinterpret_cast<TA const *>(input_base_ptr),
/*B=*/reinterpret_cast<TB const *>(hadamard_matrix.dptr),
/*QA=*/reinterpret_cast<TQA *>(rowwise_data_base_ptr),
/*SFA=*/reinterpret_cast<TSFA *>(rowwise_scale_inv_base_ptr),
/*QA_COLWISE=*/reinterpret_cast<TQA *>(colwise_data_base_ptr),
/*SFA_COLWISE=*/reinterpret_cast<TSFA *>(colwise_scale_inv_base_ptr),
/*amax_rowwise=*/reinterpret_cast<float *>(amax_rowwise_base_ptr),
/*amax_colwise=*/reinterpret_cast<float *>(amax_colwise_base_ptr),
/*offsets=*/offsets_ptr,
/*first_dims=*/first_dims_ptr,
/*rng_state=*/rng_state,
/*tile_scheduler_workspace=*/tile_scheduler_workspace,
/*sm_count=*/sm_count,
/*stream=*/stream, /*k_tile_size=*/k_tile_size);
} else {
NVTE_ERROR("Invalid kernel configuration (kEnableRHTColQuant=",
kEnableRhtColQuant, ", kEnableRowQuant=", kEnableRowQuant, ").");
}
);););););
}
} // namespace transformer_engine
void nvte_group_hadamard_transform_cast_fusion_graph_safe(
const NVTEGroupedTensor input, NVTEGroupedTensor output, const NVTETensor hadamard_matrix,
const NVTEQuantizationConfig quant_config, NVTETensor quant_workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_hadamard_transform_cast_fusion_graph_safe);
using namespace transformer_engine;
GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input);
GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
Tensor *quant_workspace_tensor = convertNVTETensorCheck(quant_workspace);
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}
if (input_tensor->num_tensors == 0) {
return;
}
  // Call the multi-tensor Hadamard transform cast fusion implementation.
group_hadamard_transform_cast_fusion_graph_safe(
input_tensor, output_tensor, *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp,
*quant_workspace_tensor, stream);
}
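// A hedged usage sketch for the C API above. The tensor-construction helpers are placeholders for
// whatever the caller already uses to build NVTE handles; only the requirements below (BF16 16x16
// Hadamard matrix, logical dims divisible by 128, >= 4-byte workspace, optional quantization
// config) follow from the checks in group_hadamard_transform_cast_fusion_graph_safe.
//
//   NVTEGroupedTensor in   = /* grouped BF16 input, both logical dims divisible by 128 */;
//   NVTEGroupedTensor out  = /* grouped NVFP4 output with rowwise and/or columnwise data */;
//   NVTETensor hadamard    = /* 16 x 16 BF16 Hadamard matrix */;
//   NVTETensor workspace   = /* at least 4 bytes, used for the tile-scheduler counter */;
//   nvte_group_hadamard_transform_cast_fusion_graph_safe(in, out, hadamard,
//                                                        /*quant_config=*/nullptr, workspace,
//                                                        stream);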
......@@ -31,6 +31,7 @@ extern "C" {
enum class NVTE_Activation_Type {
GELU,
GEGLU,
GLU,
SILU,
SWIGLU,
RELU,
......@@ -52,6 +53,16 @@ enum class NVTE_Activation_Type {
*/
void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the GeLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_gelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
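/* Hedged usage sketch (handle construction is the caller's responsibility; only the signature and
 * the MXFP8 note above come from this header):
 *
 *   NVTEGroupedTensor act_in  = ...;  // grouped high-precision input
 *   NVTEGroupedTensor act_out = ...;  // grouped output, e.g. created with NVTE_MXFP8_1D_SCALING
 *   nvte_group_gelu(act_in, act_out, stream);
 */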
/*! \brief Computes the SiLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -62,6 +73,16 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the SiLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_silu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -72,6 +93,16 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the ReLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_relu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -82,6 +113,16 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_qgelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -92,6 +133,16 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
*/
void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation of the grouped input.
* If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_srelu(const NVTEGroupedTensor input, NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -104,6 +155,18 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the GeLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the SiLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -116,6 +179,18 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the SiLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -128,6 +203,18 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the ReLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                      NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -140,6 +227,18 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output
void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the Quick GeLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                       NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......@@ -152,6 +251,44 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
/*! \brief Computes the Squared ReLU activation gradient of the grouped input.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming grouped gradient.
* \param[in] input Input grouped tensor for activation.
* \param[in,out] output Output grouped tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
                       NVTEGroupedTensor output, cudaStream_t stream);
/*! \brief Computes the GLU (Gated Linear Unit) activation of the input.
* GLU(a,b) = sigmoid(a) * b
* See "Language Modeling with Gated Convolutional Networks" (arXiv:1612.08083)
* and "GLU Variants Improve Transformer" (arXiv:2002.05202).
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
 * \param[in,out] output Output tensor of shape [N, H],
 *                computed as sigmoid(input[:, :H]) * input[:, H:].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream);
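/* Reference sketch of the per-row computation described above (illustrative; the real kernel
 * operates on NVTETensor handles and may quantize the output):
 *
 *   for (size_t j = 0; j < H; ++j)
 *     out[j] = (1.f / (1.f + expf(-in[j]))) * in[H + j];  // sigmoid(a) * b
 */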
/*! \brief Computes the GLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);
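/* For reference, with s = sigmoid(input[:, :H]) the backward pass computes
 *   d_input[:, :H] = grad * input[:, H:] * s * (1 - s)
 *   d_input[:, H:] = grad * s
 * (standard GLU gradient, stated here for clarity rather than taken from this header).
 */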
/*! \brief Computes the gated GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
......
......@@ -89,6 +89,17 @@ extern "C" {
*/
void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Casts input grouped tensor to MXFP8.
* The type of quantized tensor in the output depends on the scaling mode of the output
* tensor. See file level comments.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in,out] output Output grouped MXFP8 tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize(const NVTEGroupedTensor input, NVTEGroupedTensor output,
cudaStream_t stream);
/*! \brief Casts input tensor to FP8/MXFP8/BlockwiseFP8, providing the option to immediately exit the kernel
* based on the value of the 'noop' tensor.
* The type of quantized tensor in the output depends on the scaling mode of the output
......@@ -132,6 +143,26 @@ void nvte_quantize_v2(const NVTETensor input, NVTETensor output,
void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor dbias,
NVTETensor workspace, cudaStream_t stream);
/*! \brief Casts input grouped tensor to MXFP8. Additionally, reduces the input along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
 * - `output` is equal to `cast(input)`
 * - `dbias` is equal to `reduce(input, dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
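/* Hedged sketch of the two-call workspace-query pattern described above (buffer allocation and
 * re-binding are left to the caller):
 *
 *   1. Call nvte_group_quantize_dbias(input, output, dbias, empty_workspace, stream); the call
 *      only fills in the required workspace shape and dtype.
 *   2. Allocate a buffer of that shape/dtype and bind it to the workspace tensor.
 *   3. Call nvte_group_quantize_dbias(input, output, dbias, workspace, stream) to run the kernel.
 */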
/*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the GeLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
......@@ -155,6 +186,29 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of GeLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the GeLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the SiLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
......@@ -178,6 +232,29 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of SiLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the SiLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the ReLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
......@@ -201,6 +278,29 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of ReLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the ReLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Quick GeLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
......@@ -224,6 +324,29 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of Quick GeLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Quick GeLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Squared ReLU backward along columns.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
......@@ -247,6 +370,29 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp
NVTETensor output, NVTETensor dbias, NVTETensor workspace,
cudaStream_t stream);
/*! \brief Computes backward of Squared ReLU operation on the grouped input, then casts to FP8/MXFP8.
* Additionally, reduces the result of the Squared ReLU backward along columns.
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* This function produces 2 results:
* - `output` is equal to `cast(dact(input))`
* - `dbias` is equal to `reduce(dact(input), dim=1)`
*
* Calling this function with the workspace being an empty tensor will not perform the operation,
* but instead set the shape and type of the workspace tensor to the required values.
*
* \param[in] input Input grouped tensor to be cast.
* \param[in] act_input Activation input grouped tensor.
* \param[in,out] output Output grouped FP8/MXFP8 tensor.
* \param[out] dbias Result of the reduction of the input along columns.
* \param[out] workspace Workspace tensor.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
/*! \brief Casts input tensor from reduced to higher precision.
* If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING,
* the block dequantization (MXFP8) of the specified shape of the block will be used.
......