Unverified commit 0f68f7b2, authored by Zhongbo Zhu and committed by GitHub
Browse files

[PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119)



* add noop to comp amax
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* fix for fp8 blockwise recipe
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* resolve comments
Signed-off-by: zhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: zhongboz <zhongboz@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent f378eaf2
......@@ -84,6 +84,21 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
*/
void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Compute an FP8 tensor's amax with quantization config.
 *
 *  The amax (maximum absolute value) of the input tensor is computed
 *  and written to the amax buffer of the output tensor, using the provided
 *  quantization configuration.
 *  One useful config entry is the noop-flag tensor: when the flag is set to
 *  1.0f on the device, the amax computation is skipped. This is needed so the
 *  operation can be safely captured and replayed under CUDA graphs.
 *
 *  \param[in]     input   Input tensor. Must be unquantized.
 *  \param[in,out] output  Output tensor. Must be an FP8 tensor with per-tensor scaling.
 *  \param[in]     config  Quantization configuration (may carry the noop tensor).
 *  \param[in]     stream  CUDA stream used for the operation.
 */
void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output,
                                   const NVTEQuantizationConfig config, cudaStream_t stream);
/*! \brief Update an FP8 tensor's scale based on its amax.
*
* This is only supported for FP8 tensors with per-tensor scaling.
......
......@@ -23,7 +23,11 @@ constexpr int amax_kernel_threads = 512;
template <int nvec, bool aligned, typename InputType>
__launch_bounds__(amax_kernel_threads) __global__
void amax_kernel(const InputType *input, float *amax, const size_t N,
const size_t num_aligned_elements) {
const size_t num_aligned_elements, const float *noop_ptr) {
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
VectorizedLoader<InputType, nvec, aligned> loader(input, N);
InputType max = 0.f;
const int warp_id = threadIdx.x / THREADS_PER_WARP;
......@@ -58,7 +62,8 @@ __launch_bounds__(amax_kernel_threads) __global__
}
template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) {
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, const float *noop_ptr,
cudaStream_t stream) {
// Zero out amax so we can update with atomic max
NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream));
......@@ -81,16 +86,17 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
switch (align) {
case Alignment::SAME_ALIGNED:
amax_kernel<nvec, true, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements, noop_ptr);
break;
case Alignment::SAME_UNALIGNED:
amax_kernel<nvec, false, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements);
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements, noop_ptr);
break;
case Alignment::DIFFERENT: {
// This case is a logic error, since there is only one pointer (input)
// in the alignment check. Still safe to process without vectorization.
amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N);
amax_kernel<1, true, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, N, noop_ptr);
break;
}
}
......@@ -102,8 +108,10 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
} // namespace
} // namespace transformer_engine
void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
NVTE_API_CALL(nvte_compute_amax);
namespace {
void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream,
const NVTEQuantizationConfig config_) {
using namespace transformer_engine;
// Check input tensor
......@@ -138,12 +146,35 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
to_string(output.amax.dtype), ")");
CheckOutputTensor(output, "output_compute_amax", true);
float *noop_ptr = nullptr;
if (config_ != nullptr) {
const QuantizationConfig *config_cpp = reinterpret_cast<const QuantizationConfig *>(config_);
// extract noop tensor from quant_config_cpp if it's not null
const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr;
noop_ptr = reinterpret_cast<float *>(
(noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr));
}
// Compute amax
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
stream);); // NOLINT(*)
noop_ptr, stream);); // NOLINT(*)
}
} // anonymous namespace
/*! Public entry point: compute the amax of `input_` into `output_`'s amax buffer.
 *
 *  Thin wrapper over compute_amax_impl with a null quantization config, so no
 *  noop-flag tensor is consulted and the amax kernel always executes.
 */
void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
  NVTE_API_CALL(nvte_compute_amax);
  compute_amax_impl(input_, output_, stream, nullptr);
}
/*! Public entry point: compute the amax of `input_` with a quantization config.
 *
 *  Forwards to compute_amax_impl, passing `config_` through so the
 *  implementation can extract the optional noop-flag tensor (used to skip
 *  work during CUDA graph replay).
 */
void nvte_compute_amax_with_config(const NVTETensor input_, const NVTETensor output_,
                                   const NVTEQuantizationConfig config_, cudaStream_t stream) {
  NVTE_API_CALL(nvte_compute_amax_with_config);
  compute_amax_impl(input_, output_, stream, config_);
}
namespace transformer_engine {
......@@ -151,7 +182,11 @@ namespace {
/*! CUDA kernel: derive an FP8 scale from a tensor's amax.
 *
 *  Launched with a single thread (<<<1, 1>>>, see the launch site in
 *  nvte_compute_scale_from_amax): reads *amax_ptr and writes the derived
 *  scale to *scale_ptr via the compute_scale_from_amax helper.
 *
 *  \param[in]  amax_ptr            Device pointer to the amax value.
 *  \param[out] scale_ptr           Device pointer receiving the computed scale.
 *  \param[in]  max_fp8             Max representable value of the target FP8 format.
 *  \param[in]  force_pow_2_scales  Whether to force power-of-two scales.
 *  \param[in]  epsilon             Forwarded to compute_scale_from_amax
 *                                  (presumably a floor on amax — confirm in helper).
 *  \param[in]  noop_ptr            Optional noop flag; when non-null and equal to
 *                                  1.0f the kernel returns without writing
 *                                  (needed for CUDA graph replay).
 */
__global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
                                               const float max_fp8, const bool force_pow_2_scales,
                                               const float epsilon, const float *noop_ptr) {
  // Skip all work when the noop flag is raised (e.g. during CUDA graph replay).
  if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
    return;
  }
  *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon,
                                       std::numeric_limits<float>::max());
}
......@@ -197,10 +232,21 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType,
max_fp8 = Quantized_Limits<DType>::max_norm;);
// noop tensor for cuda graph
float *noop_ptr = nullptr;
if (config_ != nullptr) {
const QuantizationConfig *config_cpp = reinterpret_cast<const QuantizationConfig *>(config_);
// extract noop tensor from quant_config_cpp if it's not null
const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr;
noop_ptr = reinterpret_cast<float *>(
(noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr));
}
// Update scale
compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(
reinterpret_cast<const float *>(output.amax.dptr),
reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales,
config.amax_epsilon);
config.amax_epsilon, noop_ptr);
NVTE_CHECK_CUDA(cudaGetLastError());
}
......@@ -27,7 +27,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor
SimpleTensor &scale_inv_t, SimpleTensor &output,
SimpleTensor &output_t, const float epsilon,
const bool return_transpose, const bool pow_2_scale,
cudaStream_t stream);
const SimpleTensor &noop_tensor, cudaStream_t stream);
// enum class for rowwise usage
enum class FP8BlockwiseRowwiseOption {
......@@ -59,7 +59,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor
SimpleTensor &output_t, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option,
FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow_2_scale, cudaStream_t stream);
const bool pow_2_scale, const SimpleTensor &noop_tensor,
cudaStream_t stream);
} // namespace transformer_engine::detail
......
......@@ -70,11 +70,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
const size_t scale_stride_y, const size_t scale_t_stride_x,
const size_t scale_t_stride_y, const float epsilon,
const __grid_constant__ CUtensorMap tensor_map_output_t,
bool pow_2_scaling) {
bool pow_2_scaling, const float* noop_ptr) {
using IVec = Vec<IType, THREAD_TILE_DIM_X>;
using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
// shared mem for amax reduction in entire block, each warp produces one amax, there are
// NUM_WARPS_IN_BLOCK amax to reduce
__shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK];
......@@ -249,11 +253,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
bool pow_2_scaling) {
bool pow_2_scaling, const float* noop_ptr) {
using IVec = Vec<IType, THREAD_TILE_DIM_X>;
using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
// shared mem for amax reduction in entire block, each warp produces one amax, there are
// NUM_WARPS_IN_BLOCK amax to reduce
__shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK];
......@@ -473,7 +481,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
SimpleTensor& scale_inv_t, SimpleTensor& output,
SimpleTensor& output_t, const float epsilon,
const bool return_transpose, const bool pow_2_scale,
cudaStream_t stream) {
const SimpleTensor& noop_tensor, cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_square_blockwise);
checkCuDriverContext(stream);
......@@ -494,6 +502,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
size_t scale_t_stride_x = 0;
size_t scale_t_stride_y = 0;
const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
if (return_transpose) {
NVTE_CHECK(output_t.shape.size() == input.shape.size(),
"output_t must have same number of dimensions as input.");
......@@ -541,7 +551,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
tensor_map_output_trans, pow_2_scale);
tensor_map_output_trans, pow_2_scale, noop_ptr);
} else {
block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType,
OutputType>
......@@ -552,7 +562,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
pow_2_scale);
pow_2_scale, noop_ptr);
} // full-tile
) // return_transpose
) // OutputType
......
......@@ -172,7 +172,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow_2_scaling) {
const bool pow_2_scaling, const float* noop_ptr) {
// skip execution if noop
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
bool return_columnwise_gemm_ready =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
......@@ -520,7 +525,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
SimpleTensor& output_t, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option,
FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow2_scale, cudaStream_t stream) {
const bool pow2_scale, const SimpleTensor& noop_tensor,
cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise);
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
......@@ -585,6 +591,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType,
......@@ -613,7 +621,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
columnwise_option, pow2_scale);) // kAligned
columnwise_option, pow2_scale, noop_ptr);) // kAligned
) // OutputType
) // InputType
NVTE_CHECK_CUDA(cudaGetLastError());
......
......@@ -1427,7 +1427,8 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream);
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
/*noop_tensor=*/noop_tensor.data, stream);
break;
}
case NVTE_BLOCK_SCALING_1D: {
......@@ -1455,10 +1456,10 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
}
quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv,
output_tensor->columnwise_scale_inv, output_tensor->data,
output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, stream);
quantize_transpose_vector_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, noop_tensor.data, stream);
break;
}
default:
......
......@@ -518,7 +518,8 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te
// Compute amax
if (compute_amax) {
NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); });
NVTE_SCOPED_GIL_RELEASE(
{ nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); });
}
// Perform amax reduction if needed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment