"vscode:/vscode.git/clone" did not exist on "89988ec8c2a0c3e18e63767d9df5ca8f6b8ff21c"
Unverified Commit fcfa0c3c authored by Hongbin Liu's avatar Hongbin Liu Committed by GitHub
Browse files

(Bug fix) Fix accuracy issue for blockwise scaling+E8 scale on Blackwell (#2589)



* bug fix
Signed-off-by: default avatarhongbinl <hongbinl@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/common/swizzle/swizzle_block_scaling.cu

Mask to 8 bits to prevent potential bit overlap
Co-authored-by: default avatargreptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: default avatarHongbin Liu  <lhb8125@users.noreply.github.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Update transformer_engine/common/swizzle/swizzle_block_scaling.cu
Co-authored-by: default avatargreptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: default avatarHongbin Liu  <lhb8125@users.noreply.github.com>

* fix bug in 2d too
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>

---------
Signed-off-by: default avatarhongbinl <hongbinl@nvidia.com>
Signed-off-by: default avatarHongbin Liu  <lhb8125@users.noreply.github.com>
Signed-off-by: default avatarHongbin Liu <hongbinl@nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatargreptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
parent bd007993
......@@ -113,7 +113,8 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
}
// pack the exponent bits of the scaling factors
uint32_t packed_exponents = (sf.x >> 23) | (sf.y >> 15) | (sf.z >> 7) | (sf.w << 1);
uint32_t packed_exponents = ((sf.x >> 23) & 0xFF) | (((sf.y >> 23) & 0xFF) << 8) |
(((sf.z >> 23) & 0xFF) << 16) | (((sf.w >> 23) & 0xFF) << 24);
// partially swizzle the scaling factors
constexpr uint32_t ACTIVE_MASK = 0xFFFFFFFF; // no divergent branches
......@@ -198,8 +199,9 @@ void __global__ __launch_bounds__(WARPS_X_PER_TB* WARPS_Y_PER_TB* WARP_SIZE)
uint32_t sf = *reinterpret_cast<const uint32_t*>(warp_src);
// broadcast it to four scaling factors for 1x32 tiles
sf = (sf << 1) | (sf >> 7);
sf = sf | (sf >> 16);
// extract and broadcast the exponent byte to four bytes for E8M0 format
uint32_t exp_byte = (sf >> 23) & 0xFF;
sf = exp_byte | (exp_byte << 8) | (exp_byte << 16) | (exp_byte << 24);
// broadcast it to sixteen scaling factors for 1x32 tiles
const uint4 sf4{sf, sf, sf, sf};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment