Unverified commit 05d3b7b5, authored by Jan Bielak, committed by GitHub
Browse files

[PyTorch] Fix normalization+amax forward CS fusion to work for untuned kernels (#2061)



* Compute amax in normalization forward in current scaling in untuned kernels
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6a4e871e
...@@ -215,6 +215,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne ...@@ -215,6 +215,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
scale = *reinterpret_cast<compute_t *>(params.scale); scale = *reinterpret_cast<compute_t *>(params.scale);
} }
compute_t amax = 0; compute_t amax = 0;
const bool requires_amax = params.amax != nullptr;
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m; const int row = cta_row + warp_m;
...@@ -283,14 +284,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne ...@@ -283,14 +284,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
} }
// Apply fp8 factors // Apply fp8 factors
if (params.fp8_out) { if (params.fp8_out || requires_amax) {
#pragma unroll #pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) { for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) { if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt]; compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0); __builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij)); amax = fmaxf(amax, fabsf(z_ij));
z.data.elt[jt] = z_ij * scale; if (params.fp8_out) {
z.data.elt[jt] = z_ij * scale;
}
} }
} }
} }
...@@ -302,17 +305,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne ...@@ -302,17 +305,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne
} }
} }
// Finalize fp8 factors // Reduce amax over block
if (params.fp8_out) { if (requires_amax) {
// Reduce amax over block amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (params.amax != nullptr) { if (threadIdx.x == 0) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp); static_assert(std::is_same<compute_t, float>::value);
if (threadIdx.x == 0) { atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
} }
}
if (params.fp8_out) {
// Update scale-inverse // Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale); reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
......
...@@ -205,6 +205,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ ...@@ -205,6 +205,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
scale = *reinterpret_cast<compute_t *>(params.scale); scale = *reinterpret_cast<compute_t *>(params.scale);
} }
compute_t amax = 0; compute_t amax = 0;
const bool requires_amax = params.amax != nullptr;
for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) {
const int row = cta_row + warp_m; const int row = cta_row + warp_m;
...@@ -258,14 +259,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ ...@@ -258,14 +259,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
} }
// Apply fp8 factors // Apply fp8 factors
if (params.fp8_out) { if (params.fp8_out || requires_amax) {
#pragma unroll #pragma unroll
for (int jt = 0; jt < NUM_ELTS; jt++) { for (int jt = 0; jt < NUM_ELTS; jt++) {
if (col + jt < params.cols) { if (col + jt < params.cols) {
compute_t z_ij = z.data.elt[jt]; compute_t z_ij = z.data.elt[jt];
__builtin_assume(amax >= 0); __builtin_assume(amax >= 0);
amax = fmaxf(amax, fabsf(z_ij)); amax = fmaxf(amax, fabsf(z_ij));
z.data.elt[jt] = z_ij * scale; if (params.fp8_out) {
z.data.elt[jt] = z_ij * scale;
}
} }
} }
} }
...@@ -277,17 +280,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ ...@@ -277,17 +280,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_
} }
} }
// Finalize fp8 factors // Reduce amax over block
if (params.fp8_out) { if (requires_amax) {
// Reduce amax over block amax = reduce_max<WARPS_M * WARPS_N>(amax, warp);
if (params.amax != nullptr) { if (threadIdx.x == 0) {
amax = reduce_max<WARPS_M * WARPS_N>(amax, warp); static_assert(std::is_same<compute_t, float>::value);
if (threadIdx.x == 0) { atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
static_assert(std::is_same<compute_t, float>::value);
atomicMaxFloat(reinterpret_cast<compute_t *>(params.amax), amax);
}
} }
}
if (params.fp8_out) {
// Update scale-inverse // Update scale-inverse
if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) {
reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale); reciprocal<compute_t>(reinterpret_cast<compute_t *>(params.scale_inv), scale);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment