Unverified commit bfab8c67, authored by Xin Yao, committed by GitHub
Browse files

[Common] PDL for Quantization Kernels (#2001)



* PDL for MXFP8 Quantize
Signed-off-by: Xin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Xin Yao <xiny@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent c5ee5fd0
...@@ -203,6 +203,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ...@@ -203,6 +203,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Wait for the data to have arrived // Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], parity); ptx::mbarrier_wait_parity(&mbar[stage], parity);
// Trigger the next kernel, so its TMA load can be overlapped with the current kernel
if (stage == STAGES - 1) {
cudaTriggerProgrammaticLaunchCompletion();
}
float thread_amax = 0.0f; float thread_amax = 0.0f;
if constexpr (COLWISE_SCALING) { if constexpr (COLWISE_SCALING) {
const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
...@@ -1121,6 +1126,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -1121,6 +1126,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
cudaLaunchConfig_t cfg = {grid, block_size, dshmem_size, stream, NULL, 0};
// This kernel will only be called on sm100+, so no need to check sm_arch
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1; cfg.attrs = attribute;
cfg.numAttrs = 1;
switch (scaling_type) { switch (scaling_type) {
case ScalingType::ROWWISE: case ScalingType::ROWWISE:
cudaFuncSetAttribute( cudaFuncSetAttribute(
...@@ -1128,13 +1140,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -1128,13 +1140,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cudaLaunchKernelEx(
&cfg,
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK> false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
scale_stride_colwise);
break; break;
case ScalingType::COLWISE: case ScalingType::COLWISE:
cudaFuncSetAttribute( cudaFuncSetAttribute(
...@@ -1142,13 +1154,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -1142,13 +1154,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cudaLaunchKernelEx(
&cfg,
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false, cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK> true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
scale_stride_colwise);
break; break;
case ScalingType::BIDIMENSIONAL: case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute( cudaFuncSetAttribute(
...@@ -1156,13 +1168,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, ...@@ -1156,13 +1168,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, true, cudaLaunchKernelEx(
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK> &cfg,
<<<grid, block_size, dshmem_size, stream>>>( cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
scale_stride_colwise);
break; break;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment