Unverified Commit 0f68f7b2 authored by Zhongbo Zhu's avatar Zhongbo Zhu Committed by GitHub
Browse files

[PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119)



* add noop to comp amax
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* fix for fp8 blockwise recipe
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* resolve comments
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: default avatarzhongboz <zhongboz@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarTim Moon <4406448+timmoon10@users.noreply.github.com>
parent f378eaf2
...@@ -84,6 +84,21 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( ...@@ -84,6 +84,21 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction(
*/ */
void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream); void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream);
/*! \brief Compute an FP8 tensor's amax with quantization config.
*
* The amax (maximum absolute value) of the input tensor is computed
* and written to the amax buffer of the output tensor, using the provided
* quantization configuration.
* One useful config is the noop tensor, which is needed by cuda graph.
*
* \param[in] input Input tensor. Must be unquantized.
* \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling.
* \param[in] config Quantization configuration.
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output,
const NVTEQuantizationConfig config, cudaStream_t stream);
/*! \brief Update an FP8 tensor's scale based on its amax. /*! \brief Update an FP8 tensor's scale based on its amax.
* *
* This is only supported for FP8 tensors with per-tensor scaling. * This is only supported for FP8 tensors with per-tensor scaling.
......
...@@ -23,7 +23,11 @@ constexpr int amax_kernel_threads = 512; ...@@ -23,7 +23,11 @@ constexpr int amax_kernel_threads = 512;
template <int nvec, bool aligned, typename InputType> template <int nvec, bool aligned, typename InputType>
__launch_bounds__(amax_kernel_threads) __global__ __launch_bounds__(amax_kernel_threads) __global__
void amax_kernel(const InputType *input, float *amax, const size_t N, void amax_kernel(const InputType *input, float *amax, const size_t N,
const size_t num_aligned_elements) { const size_t num_aligned_elements, const float *noop_ptr) {
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
VectorizedLoader<InputType, nvec, aligned> loader(input, N); VectorizedLoader<InputType, nvec, aligned> loader(input, N);
InputType max = 0.f; InputType max = 0.f;
const int warp_id = threadIdx.x / THREADS_PER_WARP; const int warp_id = threadIdx.x / THREADS_PER_WARP;
...@@ -58,7 +62,8 @@ __launch_bounds__(amax_kernel_threads) __global__ ...@@ -58,7 +62,8 @@ __launch_bounds__(amax_kernel_threads) __global__
} }
template <int nvec, typename InputType> template <int nvec, typename InputType>
void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) { void launch_amax_kernel(const InputType *input, float *amax, const size_t N, const float *noop_ptr,
cudaStream_t stream) {
// Zero out amax so we can update with atomic max // Zero out amax so we can update with atomic max
NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream)); NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream));
...@@ -81,16 +86,17 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud ...@@ -81,16 +86,17 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
switch (align) { switch (align) {
case Alignment::SAME_ALIGNED: case Alignment::SAME_ALIGNED:
amax_kernel<nvec, true, InputType> amax_kernel<nvec, true, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements); <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements, noop_ptr);
break; break;
case Alignment::SAME_UNALIGNED: case Alignment::SAME_UNALIGNED:
amax_kernel<nvec, false, InputType> amax_kernel<nvec, false, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements); <<<num_blocks, threads, 0, stream>>>(input, amax, N, num_aligned_elements, noop_ptr);
break; break;
case Alignment::DIFFERENT: { case Alignment::DIFFERENT: {
// This case is a logic error, since there is only one pointer (input) // This case is a logic error, since there is only one pointer (input)
// in the alignment check. Still safe to process without vectorization. // in the alignment check. Still safe to process without vectorization.
amax_kernel<1, true, InputType><<<num_blocks, threads, 0, stream>>>(input, amax, N, N); amax_kernel<1, true, InputType>
<<<num_blocks, threads, 0, stream>>>(input, amax, N, N, noop_ptr);
break; break;
} }
} }
...@@ -102,8 +108,10 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud ...@@ -102,8 +108,10 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud
} // namespace } // namespace
} // namespace transformer_engine } // namespace transformer_engine
void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { namespace {
NVTE_API_CALL(nvte_compute_amax);
void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream,
const NVTEQuantizationConfig config_) {
using namespace transformer_engine; using namespace transformer_engine;
// Check input tensor // Check input tensor
...@@ -138,12 +146,35 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt ...@@ -138,12 +146,35 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
to_string(output.amax.dtype), ")"); to_string(output.amax.dtype), ")");
CheckOutputTensor(output, "output_compute_amax", true); CheckOutputTensor(output, "output_compute_amax", true);
float *noop_ptr = nullptr;
if (config_ != nullptr) {
const QuantizationConfig *config_cpp = reinterpret_cast<const QuantizationConfig *>(config_);
// extract noop tensor from quant_config_cpp if it's not null
const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr;
noop_ptr = reinterpret_cast<float *>(
(noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr));
}
// Compute amax // Compute amax
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType);
launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr), launch_amax_kernel<nvec>(reinterpret_cast<const IType *>(input.data.dptr),
reinterpret_cast<float *>(output.amax.dptr), input.data.numel(), reinterpret_cast<float *>(output.amax.dptr), input.data.numel(),
stream);); // NOLINT(*) noop_ptr, stream);); // NOLINT(*)
}
} // anonymous namespace
void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) {
NVTE_API_CALL(nvte_compute_amax);
compute_amax_impl(input_, output_, stream, nullptr);
}
void nvte_compute_amax_with_config(const NVTETensor input_, const NVTETensor output_,
const NVTEQuantizationConfig config_, cudaStream_t stream) {
NVTE_API_CALL(nvte_compute_amax_with_config);
compute_amax_impl(input_, output_, stream, config_);
} }
namespace transformer_engine { namespace transformer_engine {
...@@ -151,7 +182,11 @@ namespace { ...@@ -151,7 +182,11 @@ namespace {
__global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr, __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
const float max_fp8, const bool force_pow_2_scales, const float max_fp8, const bool force_pow_2_scales,
const float epsilon) { const float epsilon, const float *noop_ptr) {
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
*scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon, *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon,
std::numeric_limits<float>::max()); std::numeric_limits<float>::max());
} }
...@@ -197,10 +232,21 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf ...@@ -197,10 +232,21 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType,
max_fp8 = Quantized_Limits<DType>::max_norm;); max_fp8 = Quantized_Limits<DType>::max_norm;);
// noop tensor for cuda graph
float *noop_ptr = nullptr;
if (config_ != nullptr) {
const QuantizationConfig *config_cpp = reinterpret_cast<const QuantizationConfig *>(config_);
// extract noop tensor from quant_config_cpp if it's not null
const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr;
noop_ptr = reinterpret_cast<float *>(
(noop != nullptr ? convertNVTETensorCheck(noop)->data.dptr : nullptr));
}
// Update scale // Update scale
compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>( compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>(
reinterpret_cast<const float *>(output.amax.dptr), reinterpret_cast<const float *>(output.amax.dptr),
reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales, reinterpret_cast<float *>(output.scale.dptr), max_fp8, config.force_pow_2_scales,
config.amax_epsilon); config.amax_epsilon, noop_ptr);
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
} }
...@@ -27,7 +27,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor ...@@ -27,7 +27,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor
SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &scale_inv_t, SimpleTensor &output,
SimpleTensor &output_t, const float epsilon, SimpleTensor &output_t, const float epsilon,
const bool return_transpose, const bool pow_2_scale, const bool return_transpose, const bool pow_2_scale,
cudaStream_t stream); const SimpleTensor &noop_tensor, cudaStream_t stream);
// enum class for rowwise usage // enum class for rowwise usage
enum class FP8BlockwiseRowwiseOption { enum class FP8BlockwiseRowwiseOption {
...@@ -59,7 +59,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor ...@@ -59,7 +59,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor
SimpleTensor &output_t, const float epsilon, SimpleTensor &output_t, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseRowwiseOption rowwise_option,
FP8BlockwiseColumnwiseOption columnwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow_2_scale, cudaStream_t stream); const bool pow_2_scale, const SimpleTensor &noop_tensor,
cudaStream_t stream);
} // namespace transformer_engine::detail } // namespace transformer_engine::detail
......
...@@ -70,11 +70,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) ...@@ -70,11 +70,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_stride_y, const size_t scale_t_stride_x,
const size_t scale_t_stride_y, const float epsilon, const size_t scale_t_stride_y, const float epsilon,
const __grid_constant__ CUtensorMap tensor_map_output_t, const __grid_constant__ CUtensorMap tensor_map_output_t,
bool pow_2_scaling) { bool pow_2_scaling, const float* noop_ptr) {
using IVec = Vec<IType, THREAD_TILE_DIM_X>; using IVec = Vec<IType, THREAD_TILE_DIM_X>;
using OVecCast = Vec<OType, THREAD_TILE_DIM_X>; using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>; using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
// shared mem for amax reduction in entire block, each warp produces one amax, there are // shared mem for amax reduction in entire block, each warp produces one amax, there are
// NUM_WARPS_IN_BLOCK amax to reduce // NUM_WARPS_IN_BLOCK amax to reduce
__shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK];
...@@ -249,11 +253,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose ...@@ -249,11 +253,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose
CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length,
const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
bool pow_2_scaling) { bool pow_2_scaling, const float* noop_ptr) {
using IVec = Vec<IType, THREAD_TILE_DIM_X>; using IVec = Vec<IType, THREAD_TILE_DIM_X>;
using OVecCast = Vec<OType, THREAD_TILE_DIM_X>; using OVecCast = Vec<OType, THREAD_TILE_DIM_X>;
using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>; using OVecTrans = Vec<OType, THREAD_TILE_DIM_Y>;
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
// shared mem for amax reduction in entire block, each warp produces one amax, there are // shared mem for amax reduction in entire block, each warp produces one amax, there are
// NUM_WARPS_IN_BLOCK amax to reduce // NUM_WARPS_IN_BLOCK amax to reduce
__shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK];
...@@ -473,7 +481,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -473,7 +481,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& scale_inv_t, SimpleTensor& output,
SimpleTensor& output_t, const float epsilon, SimpleTensor& output_t, const float epsilon,
const bool return_transpose, const bool pow_2_scale, const bool return_transpose, const bool pow_2_scale,
cudaStream_t stream) { const SimpleTensor& noop_tensor, cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_square_blockwise); NVTE_API_CALL(quantize_transpose_square_blockwise);
checkCuDriverContext(stream); checkCuDriverContext(stream);
...@@ -494,6 +502,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -494,6 +502,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
size_t scale_t_stride_x = 0; size_t scale_t_stride_x = 0;
size_t scale_t_stride_y = 0; size_t scale_t_stride_y = 0;
const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
if (return_transpose) { if (return_transpose) {
NVTE_CHECK(output_t.shape.size() == input.shape.size(), NVTE_CHECK(output_t.shape.size() == input.shape.size(),
"output_t must have same number of dimensions as input."); "output_t must have same number of dimensions as input.");
...@@ -541,7 +551,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -541,7 +551,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
reinterpret_cast<float*>(scale_inv.dptr), reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
tensor_map_output_trans, pow_2_scale); tensor_map_output_trans, pow_2_scale, noop_ptr);
} else { } else {
block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType, block_scaled_cast_transpose_kernel_notaligned<kReturnTranspose, float, InputType,
OutputType> OutputType>
...@@ -552,7 +562,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -552,7 +562,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor
reinterpret_cast<float*>(scale_inv.dptr), reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows,
scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon,
pow_2_scale); pow_2_scale, noop_ptr);
} // full-tile } // full-tile
) // return_transpose ) // return_transpose
) // OutputType ) // OutputType
......
...@@ -172,7 +172,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo ...@@ -172,7 +172,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo
const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y,
const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow_2_scaling) { const bool pow_2_scaling, const float* noop_ptr) {
// skip execution if noop
if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) {
return;
}
bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE; bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE;
bool return_columnwise_gemm_ready = bool return_columnwise_gemm_ready =
columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
...@@ -520,7 +525,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -520,7 +525,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
SimpleTensor& output_t, const float epsilon, SimpleTensor& output_t, const float epsilon,
FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseRowwiseOption rowwise_option,
FP8BlockwiseColumnwiseOption columnwise_option, FP8BlockwiseColumnwiseOption columnwise_option,
const bool pow2_scale, cudaStream_t stream) { const bool pow2_scale, const SimpleTensor& noop_tensor,
cudaStream_t stream) {
NVTE_API_CALL(quantize_transpose_vector_blockwise); NVTE_API_CALL(quantize_transpose_vector_blockwise);
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u; const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
...@@ -585,6 +591,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -585,6 +591,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim);
const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim);
const float* noop_ptr = reinterpret_cast<const float*>(noop_tensor.dptr);
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
input.dtype, InputType, input.dtype, InputType,
...@@ -613,7 +621,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor ...@@ -613,7 +621,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor
reinterpret_cast<float*>(scale_inv.dptr), reinterpret_cast<float*>(scale_inv.dptr),
reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, reinterpret_cast<float*>(scale_inv_t.dptr), row_length, num_rows, scale_stride_x,
scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option,
columnwise_option, pow2_scale);) // kAligned columnwise_option, pow2_scale, noop_ptr);) // kAligned
) // OutputType ) // OutputType
) // InputType ) // InputType
NVTE_CHECK_CUDA(cudaGetLastError()); NVTE_CHECK_CUDA(cudaGetLastError());
......
...@@ -1427,7 +1427,8 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o ...@@ -1427,7 +1427,8 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
quantize_transpose_square_blockwise( quantize_transpose_square_blockwise(
input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->data, output_tensor->columnwise_data, epsilon, output_tensor->data, output_tensor->columnwise_data, epsilon,
/*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales,
/*noop_tensor=*/noop_tensor.data, stream);
break; break;
} }
case NVTE_BLOCK_SCALING_1D: { case NVTE_BLOCK_SCALING_1D: {
...@@ -1455,10 +1456,10 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o ...@@ -1455,10 +1456,10 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o
? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT ? FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT
: FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY;
} }
quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, quantize_transpose_vector_blockwise(
output_tensor->columnwise_scale_inv, output_tensor->data, input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv,
output_tensor->columnwise_data, epsilon, rowwise_option, output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option,
columnwise_option, force_pow_2_scales, stream); columnwise_option, force_pow_2_scales, noop_tensor.data, stream);
break; break;
} }
default: default:
......
...@@ -518,7 +518,8 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te ...@@ -518,7 +518,8 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te
// Compute amax // Compute amax
if (compute_amax) { if (compute_amax) {
NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); }); NVTE_SCOPED_GIL_RELEASE(
{ nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); });
} }
// Perform amax reduction if needed // Perform amax reduction if needed
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment