Unverified Commit 05d3b7b5 authored by Jan Bielak, committed by GitHub
Browse files

[PyTorch] Fix normalization+amax forward CS fusion to work for untuned kernels (#2061)



* Compute amax in normalization forward in current scaling in untuned kernels
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6a4e871e
......@@ -215,6 +215,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
const bool requires_amax = params.amax != nullptr;
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m;
......@@ -283,17 +284,19 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
}
// Apply fp8 factors
if (params.fp8_out) {
if (params.fp8_out || requires_amax) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij));
if (params.fp8_out) {
z.data.elt[jt] = z_ij * scale;
}
}
}
}
// Store output
Ovec z_out;
......@@ -302,10 +305,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
}
}
// Finalize fp8 factors
if (params.fp8_out) {
// Reduce amax over block
if (params.amax != nullptr) {
if (requires_amax) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
......@@ -313,6 +314,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
}
}
if (params.fp8_out) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
......
......@@ -205,6 +205,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
scale = *reinterpret_cast<compute_t *>(params.scale);
}
compute_t amax = 0;
const bool requires_amax = params.amax != nullptr;
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m;
......@@ -258,17 +259,19 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
}
// Apply fp8 factors
if (params.fp8_out) {
if (params.fp8_out || requires_amax) {
#pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij));
if (params.fp8_out) {
z.data.elt[jt] = z_ij * scale;
}
}
}
}
// Store output
Ovec z_out;
......@@ -277,10 +280,8 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
}
}
// Finalize fp8 factors
if (params.fp8_out) {
// Reduce amax over block
if (params.amax != nullptr) {
if (requires_amax) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (threadIdx.x == 0) {
static_assert(std::is_same<compute_t, float>::value);
......@@ -288,6 +289,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
}
}
if (params.fp8_out) {
// Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.