Unverified Commit d0d40631 authored by Evgeny Tsykunov's avatar Evgeny Tsykunov Committed by GitHub
Browse files

[PyTorch] Fix amax computation using output_t data in normalization (#2355)



Fix amax computation using output_t data in normalization
Signed-off-by: Evgeny <etsykunov@nvidia.com>
parent d8f1e68f
......@@ -114,8 +114,18 @@ void compute_ref_output(NormType norm_type,
tmp = current * rsigma[i] * g;
}
// Write output (scaled only for fp8 paths)
output[i * H + j] = static_cast<OutputType>(tmp * scale);
current_max = fmaxf(current_max, fabsf(tmp));
// amax semantics:
// - fp8_out (scale != 1): amax on pre-scale compute value 'tmp'
// - non-fp8_out (scale == 1): amax on value converted to OutputType (e.g., bf16)
if (scale != 1.f) {
current_max = fmaxf(current_max, fabsf(tmp));
} else {
OutputType out_t_val = static_cast<OutputType>(tmp);
current_max = fmaxf(current_max, fabsf(static_cast<compute_t>(out_t_val)));
}
}
}
......
......@@ -123,7 +123,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel(
if (requires_amax) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output));
if (params.fp8_out) {
// For fp8_out, keep amax on pre-scale compute_t
amax = fmaxf(amax, fabsf(temp_output));
} else {
// Otherwise compute amax on the value converted to output_t (e.g., bf16)
output_t out_t_val = output_t(temp_output);
amax = fmaxf(amax, fabsf(compute_t(out_t_val)));
}
}
if (params.fp8_out) {
temp_output = temp_output * scale;
......@@ -290,7 +297,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij));
if (params.fp8_out) {
// For fp8_out, keep amax on pre-scale compute_t
amax = fmaxf(amax, fabsf(z_ij));
} else {
// Otherwise compute amax on the value converted to output_t (e.g., bf16)
output_t out_t_val = output_t(z_ij);
amax = fmaxf(amax, fabsf(compute_t(out_t_val)));
}
if (params.fp8_out) {
z.data.elt[jt] = z_ij * scale;
}
......
......@@ -115,7 +115,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke
if (requires_amax) {
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(temp_output));
if (params.fp8_out) {
// For fp8_out, keep amax on pre-scale compute_t
amax = fmaxf(amax, fabsf(temp_output));
} else {
// Otherwise compute amax on the value converted to output_t (e.g., bf16)
output_t out_t_val = output_t(temp_output);
amax = fmaxf(amax, fabsf(compute_t(out_t_val)));
}
}
if (params.fp8_out) {
temp_output = temp_output * scale;
......@@ -265,7 +272,14 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij));
if (params.fp8_out) {
// For fp8_out, keep amax on pre-scale compute_t
amax = fmaxf(amax, fabsf(z_ij));
} else {
// Otherwise compute amax on the value converted to output_t (e.g., bf16)
output_t out_t_val = output_t(z_ij);
amax = fmaxf(amax, fabsf(compute_t(out_t_val)));
}
if (params.fp8_out) {
z.data.elt[jt] = z_ij * scale;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment