Unverified Commit 93d51c82 authored by Oleg Goncharov's avatar Oleg Goncharov Committed by GitHub
Browse files

[Common] Fuse pre-swizzling into grouped MXFP8 quantization kernel (#2630)



* Added GEMM-ready preswizzling option
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarOleg Goncharov <ogoncharov@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent c4175fca
...@@ -21,6 +21,7 @@ ...@@ -21,6 +21,7 @@
#include "../../util/ptx.cuh" #include "../../util/ptx.cuh"
#include "../../utils.cuh" #include "../../utils.cuh"
#include "../core/common.cuh" #include "../core/common.cuh"
#include "swizzle.cuh"
namespace transformer_engine { namespace transformer_engine {
namespace dispatch { namespace dispatch {
...@@ -231,7 +232,7 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso ...@@ -231,7 +232,7 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP, template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING, float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING,
bool COLWISE_SCALING> bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES>
__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel( __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel(
const __grid_constant__ CUtensorMap tensor_map_input_static, const __grid_constant__ CUtensorMap tensor_map_input_static,
const __grid_constant__ CUtensorMap tensor_map_act_input_static, const __grid_constant__ CUtensorMap tensor_map_act_input_static,
...@@ -250,6 +251,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel ...@@ -250,6 +251,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
using IType2 = typename ptx::FPx2<IType>; using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>; using OType2 = typename ptx::FPx2<OType>;
using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx;
if constexpr (NO_ACTIVATIONS) { if constexpr (NO_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) { if (noop != nullptr && noop[0] == 1.0f) {
return; return;
...@@ -475,8 +478,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel ...@@ -475,8 +478,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const size_t global_scales_offset_X = scales_offset_X_colwise; const size_t global_scales_offset_X = scales_offset_X_colwise;
const size_t scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
DIVUP(rows, static_cast<size_t>(128)));
} else {
scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
}
scales_colwise[scale_idx] = biased_exponent; scales_colwise[scale_idx] = biased_exponent;
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
...@@ -602,7 +611,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel ...@@ -602,7 +611,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp); ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise; const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X,
DIVUP(cols, static_cast<size_t>(128)));
} else {
scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
}
scales_rowwise[scale_idx] = biased_exponent; scales_rowwise[scale_idx] = biased_exponent;
const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
...@@ -803,6 +819,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations ...@@ -803,6 +819,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
const dim3 grid(blocks); const dim3 grid(blocks);
const size_t block_size = THREADS_PER_CHUNK; const size_t block_size = THREADS_PER_CHUNK;
const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales;
// Logical shape of a tensor with varying all dims is [1, M*K] // Logical shape of a tensor with varying all dims is [1, M*K]
if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) { if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) {
NVTE_CHECK(first_logical_dim % 128 == 0, NVTE_CHECK(first_logical_dim % 128 == 0,
...@@ -848,6 +866,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations ...@@ -848,6 +866,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
input->dtype(), IType, input->dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType, output->dtype(), OType,
TRANSFORMER_ENGINE_SWITCH_CONDITION(
with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,
alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{}; alignas(64) CUtensorMap tensor_map_act_input{};
...@@ -857,8 +877,9 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations ...@@ -857,8 +877,9 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
constexpr size_t input_type_bit_size = TypeInfo<IType>::size; constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size; constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim, create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim,
BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, input_type_bit_size); last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
input_type_bit_size);
if constexpr (IS_DACT) { if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim, create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim,
...@@ -897,22 +918,26 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations ...@@ -897,22 +918,26 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
auto kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, auto kernel =
OType, true, true>; group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
true, true, WITH_GEMM_SWIZZLED_SCALES>;
switch (scaling_type) { switch (scaling_type) {
case ScalingType::ROWWISE: { case ScalingType::ROWWISE: {
kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, kernel =
OType, true, false>; group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, false, WITH_GEMM_SWIZZLED_SCALES>;
break; break;
} }
case ScalingType::COLWISE: { case ScalingType::COLWISE: {
kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, kernel =
OType, false, true>; group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, false, true, WITH_GEMM_SWIZZLED_SCALES>;
break; break;
} }
case ScalingType::BIDIMENSIONAL: { case ScalingType::BIDIMENSIONAL: {
kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, kernel =
OType, true, true>; group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, true, WITH_GEMM_SWIZZLED_SCALES>;
break; break;
} }
} }
...@@ -933,13 +958,13 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations ...@@ -933,13 +958,13 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
update_tma_descriptors<IType, OType><<<num_tensors, 32, 0, stream>>>( update_tma_descriptors<IType, OType><<<num_tensors, 32, 0, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr, tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr,
output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim, output_colwise_dptr, shape_rep, num_tensors, first_logical_dim,
offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling, last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr,
use_colwise_scaling, IS_DACT); use_rowwise_scaling, use_colwise_scaling, IS_DACT);
} }
NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, NVTE_CHECK_CUDA(cudaFuncSetAttribute(
dshmem_size)); kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
kernel<<<grid, block_size, dshmem_size, stream>>>( kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
...@@ -953,6 +978,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations ...@@ -953,6 +978,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*) ); // NOLINT(*)
); // NOLINT(*)
} }
} // namespace mxfp8 } // namespace mxfp8
......
...@@ -335,6 +335,12 @@ struct GroupedTensor { ...@@ -335,6 +335,12 @@ struct GroupedTensor {
NVTEGroupedTensor nvte_tensor; NVTEGroupedTensor nvte_tensor;
/*! \brief Whether scaling factors are in format expected by GEMM
*
* Only meaningful for MXFP8 and NVFP4.
*/
bool with_gemm_swizzled_scales = false;
GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors) GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
: data(), : data(),
columnwise_data(), columnwise_data(),
...@@ -401,6 +407,7 @@ struct GroupedTensor { ...@@ -401,6 +407,7 @@ struct GroupedTensor {
num_tensors = 0; num_tensors = 0;
scaling_mode = NVTE_DELAYED_TENSOR_SCALING; scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
nvte_tensor = 0; nvte_tensor = 0;
with_gemm_swizzled_scales = false;
} }
}; };
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment