Unverified Commit ebca6153 authored by Xin Yao's avatar Xin Yao Committed by GitHub
Browse files

[Common] PDL for Blockwise Quantization (#2066)



* enable PDL for blockwise quantization kernels
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* add comment
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarXin Yao <xiny@nvidia.com>
parent ec65ba3c
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "common/common.h" #include "common/common.h"
#include "common/recipe/recipe_common.cuh" #include "common/recipe/recipe_common.cuh"
#include "common/util/cuda_runtime.h"
#include "common/util/ptx.cuh" #include "common/util/ptx.cuh"
#include "common/utils.cuh" #include "common/utils.cuh"
...@@ -167,6 +168,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) ...@@ -167,6 +168,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
} }
} }
// Trigger the next kernel here so that its load from global memory can overlap with this kernel's
// store to global memory.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
// Step 3: Store cast output, Step 4: do transpose within thread tile // Step 3: Store cast output, Step 4: do transpose within thread tile
OVecCast tmp_output_c; OVecCast tmp_output_c;
...@@ -390,6 +397,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose ...@@ -390,6 +397,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
} }
} }
// Trigger the next kernel here so that its load from global memory can overlap with this kernel's
// store to global memory.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
cudaTriggerProgrammaticLaunchCompletion();
#endif
// Step 3: Store cast output, Step 4: do transpose within thread tile // Step 3: Store cast output, Step 4: do transpose within thread tile
// Edge case: in the non-full tile case, there are three subcases // Edge case: in the non-full tile case, there are three subcases
// for full thread tile, it's the same thing here // for full thread tile, it's the same thing here
...@@ -511,6 +524,15 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -511,6 +524,15 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM); const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM);
const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM); const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM);
dim3 grid(num_blocks_x, num_blocks_y, 1);
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
cudaLaunchConfig_t cfg = {grid, THREADS_PER_BLOCK, 0, stream, NULL, 0};
if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= 90) {
cfg.attrs = attribute;
cfg.numAttrs = 1;
}
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType, input.dtype, InputType,
...@@ -521,7 +543,6 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -521,7 +543,6 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
TRANSFORMER_ENGINE_SWITCH_CONDITION( TRANSFORMER_ENGINE_SWITCH_CONDITION(
return_transpose, kReturnTranspose, return_transpose, kReturnTranspose,
dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile = const bool full_tile =
row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0;
...@@ -531,19 +552,21 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -531,19 +552,21 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
tensor_map_output_trans = tensor_map_output_trans =
get_tensor_map<OutputType>(output_t, num_rows, row_length); get_tensor_map<OutputType>(output_t, num_rows, row_length);
} }
block_scaled_cast_transpose_kernel<kReturnTranspose, float, InputType, OutputType> cudaLaunchKernelEx(&cfg,
<<<grid, THREADS_PER_BLOCK, 0, stream>>>( block_scaled_cast_transpose_kernel<kReturnTranspose, float,
InputType, OutputType>,
reinterpret_cast<const InputType*>(input.dptr), reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr), reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr), reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr), reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, scale_stride_x, scale_stride_y, scale_t_stride_x,
tensor_map_output_trans, pow_2_scale); scale_t_stride_y, epsilon, tensor_map_output_trans, pow_2_scale);
} else { } else {
block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType, cudaLaunchKernelEx(
OutputType> &cfg,
<<<grid, THREADS_PER_BLOCK, 0, stream>>>( block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float,
InputType, OutputType>,
reinterpret_cast<const InputType*>(input.dptr), reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr), reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr), reinterpret_cast<OutputType*>(output_t.dptr),
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "common/common.h" #include "common/common.h"
#include "common/recipe/recipe_common.cuh" #include "common/recipe/recipe_common.cuh"
#include "common/transpose/cast_transpose.h" #include "common/transpose/cast_transpose.h"
#include "common/util/cuda_runtime.h"
#include "common/utils.cuh" #include "common/utils.cuh"
namespace transformer_engine { namespace transformer_engine {
...@@ -234,6 +235,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -234,6 +235,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
__syncthreads(); __syncthreads();
// If not return columnwise, we trigger the next kernel here so that its load from global memory
// can overlap with this kernel's return rowwise.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if (!return_columnwise_gemm_ready && !return_columnwise_compact) {
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
// Step 2: Cast and store to output_c // Step 2: Cast and store to output_c
if (return_rowwise) { if (return_rowwise) {
constexpr int r_stride = constexpr int r_stride =
...@@ -325,6 +334,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -325,6 +334,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
} }
} }
// If return columnwise, we trigger the next kernel here so that its load from global memory
// can overlap with this kernel's return columnwise.
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if (return_columnwise_gemm_ready || return_columnwise_compact) {
cudaTriggerProgrammaticLaunchCompletion();
}
#endif
// Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t // Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t
if (return_columnwise_gemm_ready) { if (return_columnwise_gemm_ready) {
constexpr int c_stride = constexpr int c_stride =
...@@ -584,6 +601,10 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -584,6 +601,10 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
dim3 grid(num_blocks_x, num_blocks_y, 1);
cudaLaunchAttribute attribute[1];
attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attribute[0].val.programmaticStreamSerializationAllowed = 1;
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType, input.dtype, InputType,
...@@ -591,29 +612,36 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -591,29 +612,36 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output.dtype, OutputType, output.dtype, OutputType,
dim3 grid(num_blocks_x, num_blocks_y, 1);
const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0;
TRANSFORMER_ENGINE_SWITCH_CONDITION( TRANSFORMER_ENGINE_SWITCH_CONDITION(
full_tile, kAligned, full_tile, kAligned,
size_t smem_bytes = kSMemSize * sizeof(InputType); size_t smem_bytes = kSMemSize * sizeof(InputType);
cudaLaunchConfig_t cfg = {grid, kThreadsPerBlock, smem_bytes, stream, NULL, 0};
if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >=
90) {
cfg.attrs = attribute;
cfg.numAttrs = 1;
}
// shared memory must be requested up // shared memory must be requested up
if (smem_bytes >= 48 * 1024) { if (smem_bytes >= 48 * 1024) {
cudaError_t err = cudaFuncSetAttribute( cudaError_t err = cudaFuncSetAttribute(
&block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>, &block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType>,
cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size."); NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size.");
} block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType, OutputType> } cudaLaunchKernelEx(&cfg,
<<<grid, kThreadsPerBlock, smem_bytes, stream>>>( block_scaled_1d_cast_transpose_kernel<kAligned, float, InputType,
OutputType>,
reinterpret_cast<const InputType*>(input.dptr), reinterpret_cast<const InputType*>(input.dptr),
reinterpret_cast<OutputType*>(output.dptr), reinterpret_cast<OutputType*>(output.dptr),
reinterpret_cast<OutputType*>(output_t.dptr), reinterpret_cast<OutputType*>(output_t.dptr),
reinterpret_cast<float*>(scale_inv.dptr), reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, scale_stride_x, scale_stride_y, scale_t_stride_x,
columnwise_option, pow2_scale);) // kAligned scale_t_stride_y, epsilon, rowwise_option, columnwise_option,
pow2_scale);) // kAligned
) // OutputType ) // OutputType
) // InputType ) // InputType
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment