Unverified commit d972e76d, authored by jberchtold-nvidia, committed by GitHub
Browse files

Revert "[Common] PDL for Quantization Kernels" (#2114)

Revert "[Common] PDL for Quantization Kernels (#2001)"

This reverts commit bfab8c67.
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 3d0ea80a
......@@ -203,11 +203,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
// Wait for the data to have arrived
ptx::mbarrier_wait_parity(&mbar[stage], parity);
// Trigger the next kernel, so its TMA load can be overlapped with the current kernel
if (stage == STAGES - 1) {
cudaTriggerProgrammaticLaunchCompletion();
}
float thread_amax = 0.0f;
if constexpr (COLWISE_SCALING) {
const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise;
......@@ -1127,13 +1122,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
cudaLaunchConfig_t cfg = {grid, block_size, dshmem_size, stream, NULL, 0};
// This kernel will only be called on sm100+, so no need to check sm_arch
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1; cfg.attrs = attribute;
cfg.numAttrs = 1;
switch (scaling_type) {
case ScalingType::ROWWISE:
cudaFuncSetAttribute(
......@@ -1141,13 +1129,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cudaLaunchKernelEx(
&cfg,
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
case ScalingType::COLWISE:
cudaFuncSetAttribute(
......@@ -1155,13 +1143,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cudaLaunchKernelEx(
&cfg,
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, false,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
case ScalingType::BIDIMENSIONAL:
cudaFuncSetAttribute(
......@@ -1169,13 +1157,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size);
cudaLaunchKernelEx(
&cfg,
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true,
true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>,
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType, true, true,
CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>
<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr,
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise);
workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise,
scale_stride_colwise);
break;
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment