Unverified commit 060811c9, authored by jberchtold-nvidia and committed by GitHub
Browse files

[Common] Fix checks in quantize_transpose_vector_blockwise_fp4 (#2299)



fix checks in unoptimized non-rht fp4 quantize kernel
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 6273cede
......@@ -718,13 +718,11 @@ void quantize_transpose_vector_blockwise_fp4(
// raise error if pow2_scale is true
NVTE_CHECK(!pow2_scale, "No support for pow2_scale for MXFP4 for now");
if (!return_identity && !return_transpose) {
return;
}
NVTE_CHECK(return_identity || return_transpose,
"At least one of return_identity or return_transpose must be true.");
if (use_2d_quantization && !return_identity) {
return;
}
NVTE_CHECK(return_identity || !use_2d_quantization,
"2D block quantization is only supported when return_identity is true.");
const size_t row_length = input.shape.size() > 0 ? input.shape.at(input.shape.size() - 1) : 1u;
size_t num_elements = row_length;
......@@ -777,7 +775,7 @@ void quantize_transpose_vector_blockwise_fp4(
input.dtype, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP4x2_ONLY(
output.dtype, 2, OutputType,
return_identity ? output.dtype : output_t.dtype, 2, OutputType,
dim3 grid(num_blocks_x, num_blocks_y, 1);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment