Unverified Commit 93d51c82 authored by Oleg Goncharov, committed by GitHub

[Common] Fuse pre-swizzling into grouped MXFP8 quantization kernel (#2630)



* Added GEMM-ready pre-swizzling option
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c4175fca
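
For context on what "GEMM-ready pre-swizzling" means here: cuBLAS block-scaled GEMMs consume MXFP8 scale factors in a tiled 128x4 layout rather than in plain row-major order, and before this change the scales had to be rearranged by a separate swizzle kernel after quantization. The diff below computes the swizzled index directly inside the quantization kernel instead. A minimal host-side sketch of one commonly documented 128x4 swizzle follows; it is an illustration only, and the authoritative device implementation is gemm_swizzled_scale_idx in swizzle.cuh, whose exact definition may differ.

#include <cstddef>

// Illustrative (assumed) GEMM-ready scale-factor index: the scale matrix is
// split into tiles of 128 rows x 4 columns, tiles are laid out row-major, and
// within a tile the scale at (r, c) lands at (r % 32) * 16 + (r / 32) * 4 + c,
// so each tile occupies 512 contiguous one-byte E8M0 scales.
// num_col_tiles is the number of 4-column tiles, which matches the
// DIVUP(dim, 128) argument in the kernel when one scale covers 32 elements.
size_t swizzled_scale_idx(size_t row, size_t col, size_t num_col_tiles) {
  const size_t tile = (row / 128) * num_col_tiles + (col / 4);
  const size_t r = row % 128;
  const size_t c = col % 4;
  return tile * 512 + (r % 32) * 16 + (r / 32) * 4 + c;
}

Note the swapped coordinate order in the column-wise branch of the kernel below (X before Y): the column-wise scales belong to the transposed copy of the data, so the two coordinates trade places before the same swizzle is applied.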
@@ -21,6 +21,7 @@
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "../core/common.cuh"
#include "swizzle.cuh"
namespace transformer_engine {
namespace dispatch {
@@ -231,7 +232,7 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING,
-          bool COLWISE_SCALING>
+          bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES>
__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel(
const __grid_constant__ CUtensorMap tensor_map_input_static,
const __grid_constant__ CUtensorMap tensor_map_act_input_static,
@@ -250,6 +251,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;
+using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx;
if constexpr (NO_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
@@ -475,8 +478,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const size_t global_scales_offset_X = scales_offset_X_colwise;
-const size_t scale_idx =
-    global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
+size_t scale_idx = 0;
+if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+  scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
+                                      DIVUP(rows, static_cast<size_t>(128)));
+} else {
+  scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
+}
scales_colwise[scale_idx] = biased_exponent;
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
@@ -602,7 +611,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
-const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
+size_t scale_idx = 0;
+if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+  scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X,
+                                      DIVUP(cols, static_cast<size_t>(128)));
+} else {
+  scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
+}
scales_rowwise[scale_idx] = biased_exponent;
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
@@ -803,6 +819,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
const dim3 grid(blocks);
const size_t block_size = THREADS_PER_CHUNK;
+const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales;
// Logical shape of a tensor with all dims varying is [1, M*K]
if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) {
NVTE_CHECK(first_logical_dim % 128 == 0,
@@ -848,111 +866,119 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
input->dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,
-alignas(64) CUtensorMap tensor_map_input{};
-alignas(64) CUtensorMap tensor_map_act_input{};
-alignas(64) CUtensorMap tensor_map_output_rowwise{};
-alignas(64) CUtensorMap tensor_map_output_colwise{};
-constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
-constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
-create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim,
-                     BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, input_type_bit_size);
-if constexpr (IS_DACT) {
-  create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim,
-                       last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
-                       input_type_bit_size);
-}
-if (use_rowwise_scaling) {
-  create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim,
-                       last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
-                       output_type_bit_size);
-}
-if (use_colwise_scaling) {
-  create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data,
-                       first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X,
-                       last_logical_dim, 0, output_type_bit_size);
-}
-constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
-constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
-constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
-constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
-constexpr size_t buff_size_aligned_in =
-    DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
-constexpr size_t buff_size_aligned_out =
-    DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
-constexpr size_t elt_input_mem = buff_size_aligned_in;
-constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
-constexpr size_t in_mem = elt_input_mem + act_input_mem;
-const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
-const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
-const size_t out_mem = out_rowwise_mem + out_colwise_mem;
-const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
-auto kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
-                                          OType, true, true>;
-switch (scaling_type) {
-  case ScalingType::ROWWISE: {
-    kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
-                                         OType, true, false>;
-    break;
-  }
-  case ScalingType::COLWISE: {
-    kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
-                                         OType, false, true>;
-    break;
-  }
-  case ScalingType::BIDIMENSIONAL: {
-    kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
-                                         OType, true, true>;
-    break;
-  }
-}
-// Update tensor descriptors before launching the kernel
-if (!is_single_tensor) {
-  const IType *const input_dptr = reinterpret_cast<const IType *>(input->data.dptr);
-  const IType *const act_input_dptr =
-      IS_DACT ? reinterpret_cast<const IType *>(activations->data.dptr) : nullptr;
-  OType *const output_rowwise_dptr =
-      use_rowwise_scaling ? reinterpret_cast<OType *>(output->data.dptr) : nullptr;
-  OType *const output_colwise_dptr =
-      use_colwise_scaling ? reinterpret_cast<OType *>(output->columnwise_data.dptr)
-                          : nullptr;
-  update_tma_descriptors<IType, OType><<<num_tensors, 32, 0, stream>>>(
-      tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
-      tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr,
-      output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim,
-      offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling,
-      use_colwise_scaling, IS_DACT);
-}
-NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
-                                     dshmem_size));
-kernel<<<grid, block_size, dshmem_size, stream>>>(
-    tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
-    tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim,
-    last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr,
-    scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr);
-if constexpr (IS_DBIAS) {
-  common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
-}
-NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
-); // NOLINT(*)
+TRANSFORMER_ENGINE_SWITCH_CONDITION(
+    with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,
+    alignas(64) CUtensorMap tensor_map_input{};
+    alignas(64) CUtensorMap tensor_map_act_input{};
+    alignas(64) CUtensorMap tensor_map_output_rowwise{};
+    alignas(64) CUtensorMap tensor_map_output_colwise{};
+    constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
+    constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
+    create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim,
+                         last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
+                         input_type_bit_size);
+    if constexpr (IS_DACT) {
+      create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim,
+                           last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
+                           input_type_bit_size);
+    }
+    if (use_rowwise_scaling) {
+      create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim,
+                           last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
+                           output_type_bit_size);
+    }
+    if (use_colwise_scaling) {
+      create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data,
+                           first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X,
+                           last_logical_dim, 0, output_type_bit_size);
+    }
+    constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
+    constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
+    constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
+    constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
+    constexpr size_t buff_size_aligned_in =
+        DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
+    constexpr size_t buff_size_aligned_out =
+        DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
+    constexpr size_t elt_input_mem = buff_size_aligned_in;
+    constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
+    constexpr size_t in_mem = elt_input_mem + act_input_mem;
+    const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
+    const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
+    const size_t out_mem = out_rowwise_mem + out_colwise_mem;
+    const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
+    auto kernel =
+        group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
+                                    true, true, WITH_GEMM_SWIZZLED_SCALES>;
+    switch (scaling_type) {
+      case ScalingType::ROWWISE: {
+        kernel =
+            group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
+                                        OType, true, false, WITH_GEMM_SWIZZLED_SCALES>;
+        break;
+      }
+      case ScalingType::COLWISE: {
+        kernel =
+            group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
+                                        OType, false, true, WITH_GEMM_SWIZZLED_SCALES>;
+        break;
+      }
+      case ScalingType::BIDIMENSIONAL: {
+        kernel =
+            group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
+                                        OType, true, true, WITH_GEMM_SWIZZLED_SCALES>;
+        break;
+      }
+    }
+    // Update tensor descriptors before launching the kernel
+    if (!is_single_tensor) {
+      const IType *const input_dptr = reinterpret_cast<const IType *>(input->data.dptr);
+      const IType *const act_input_dptr =
+          IS_DACT ? reinterpret_cast<const IType *>(activations->data.dptr) : nullptr;
+      OType *const output_rowwise_dptr =
+          use_rowwise_scaling ? reinterpret_cast<OType *>(output->data.dptr) : nullptr;
+      OType *const output_colwise_dptr =
+          use_colwise_scaling ? reinterpret_cast<OType *>(output->columnwise_data.dptr)
+                              : nullptr;
+      update_tma_descriptors<IType, OType><<<num_tensors, 32, 0, stream>>>(
+          tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
+          tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr,
+          output_colwise_dptr, shape_rep, num_tensors, first_logical_dim,
+          last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr,
+          use_rowwise_scaling, use_colwise_scaling, IS_DACT);
+    }
+    NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+        kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
+    kernel<<<grid, block_size, dshmem_size, stream>>>(
+        tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
+        tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim,
+        last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr,
+        scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr);
+    if constexpr (IS_DBIAS) {
+      common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
+    }
+    NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
+); // NOLINT(*)
+); // NOLINT(*)
}
} // namespace mxfp8
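The launch code above wraps the type-dispatch macros in TRANSFORMER_ENGINE_SWITCH_CONDITION, which turns the runtime flag with_gemm_swizzled_scales into the compile-time template parameter WITH_GEMM_SWIZZLED_SCALES, so the kernel can select the index computation with if constexpr at no runtime cost. A minimal sketch of how such a condition switch is typically implemented follows; the macro name and exact expansion are assumptions, since the real macro lives in TE's utility headers.

#include <cstdio>

// Hypothetical stand-in for TRANSFORMER_ENGINE_SWITCH_CONDITION: compiles the
// body twice, once with the flag as constexpr true and once as constexpr
// false, and picks the branch at runtime.
#define SWITCH_CONDITION(condition, NAME, ...) \
  if (condition) {                             \
    constexpr bool NAME = true;                \
    __VA_ARGS__                                \
  } else {                                     \
    constexpr bool NAME = false;               \
    __VA_ARGS__                                \
  }

void demo(bool runtime_flag) {
  SWITCH_CONDITION(runtime_flag, SWIZZLED,
    // Inside the body the flag is a compile-time constant.
    if constexpr (SWIZZLED) { std::printf("swizzled scale layout\n"); }
    else { std::printf("linear scale layout\n"); }
  )
}

One consequence visible in the diff: both branches instantiate the kernel template, so the new parameter doubles the number of kernel instantiations per input/output type combination, trading compile time and binary size for skipping the standalone swizzle pass.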
@@ -335,6 +335,12 @@ struct GroupedTensor {
NVTEGroupedTensor nvte_tensor;
+  /*! \brief Whether the scaling factors are laid out in the format expected by GEMM
+   *
+   *  Only meaningful for MXFP8 and NVFP4.
+   */
+  bool with_gemm_swizzled_scales = false;
GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
: data(),
columnwise_data(),
@@ -401,6 +407,7 @@
num_tensors = 0;
scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
nvte_tensor = 0;
+    with_gemm_swizzled_scales = false;
}
};
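On the caller side the new field is an opt-in: leave it false for the plain row-major scale layout, or set it before quantizing to receive GEMM-ready scales and skip the standalone swizzle kernel. A hypothetical usage sketch follows, assuming the NVTE_MXFP8_1D_SCALING mode name and the constructor shown above; the surrounding setup is illustrative, not the exact TE calling convention.

// Hypothetical helper: configure a grouped MXFP8 output so the quantization
// kernel emits GEMM-ready (pre-swizzled) scale factors.
GroupedTensor make_swizzled_output(size_t num_tensors) {
  GroupedTensor output(NVTE_MXFP8_1D_SCALING, num_tensors);
  output.with_gemm_swizzled_scales = true;  // scales land in GEMM layout;
                                            // no separate swizzle pass needed
  return output;  // caller still allocates data and scale buffers as usual
}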