Unverified Commit 58996f35 authored by rasmith's avatar rasmith Committed by GitHub
Browse files

[AMD][Kernel][BugFix] Use correct scale in concat_and_cache_ds_mla_kernel when on gfx942 (#32976)


Signed-off-by: default avatarRandall Smith <ransmith@amd.com>
Signed-off-by: default avatarRandall Smith <Randall.Smith@amd.com>
Co-authored-by: default avatarRandall Smith <ransmith@amd.com>
parent b539f988
...@@ -24,6 +24,12 @@ ...@@ -24,6 +24,12 @@
typedef __hip_bfloat16 __nv_bfloat16; typedef __hip_bfloat16 __nv_bfloat16;
#endif #endif
#if defined(__gfx942__)
constexpr float kFp8ScaleDivisor = 224.f;
#else
constexpr float kFp8ScaleDivisor = 448.f;
#endif
void swap_blocks(torch::Tensor& src, torch::Tensor& dst, void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes, int64_t block_size_in_bytes,
const torch::Tensor& block_mapping) { const torch::Tensor& block_mapping) {
...@@ -401,8 +407,7 @@ __global__ void concat_and_cache_ds_mla_kernel( ...@@ -401,8 +407,7 @@ __global__ void concat_and_cache_ds_mla_kernel(
} }
// Compute the scale for the tile // Compute the scale for the tile
float tile_scale = max_abs / 448.f; float tile_scale = fmaxf(max_abs / kFp8ScaleDivisor, FLT_MIN);
tile_scale = fmaxf(tile_scale, FLT_MIN);
// The first lane of each half-warp writes the scale to kv_cache // The first lane of each half-warp writes the scale to kv_cache
if ((lane_idx == 0) || (lane_idx == 16)) { if ((lane_idx == 0) || (lane_idx == 16)) {
...@@ -471,11 +476,8 @@ __global__ void indexer_k_quant_and_cache_kernel( ...@@ -471,11 +476,8 @@ __global__ void indexer_k_quant_and_cache_kernel(
#endif #endif
} }
#if defined(__gfx942__) float scale = fmaxf(amax, 1e-4) / kFp8ScaleDivisor;
float scale = fmaxf(amax, 1e-4) / 224.0f;
#else
float scale = fmaxf(amax, 1e-4) / 448.0f;
#endif
if (use_ue8m0) { if (use_ue8m0) {
scale = exp2f(ceilf(log2f(scale))); scale = exp2f(ceilf(log2f(scale)));
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment