Unverified Commit 300a59c4 authored by Matthew Bonanni's avatar Matthew Bonanni Committed by GitHub
Browse files

Avoid division by zero in cache DS MLA kernel (#26174)


Signed-off-by: default avatarMatthew Bonanni <mbonanni@redhat.com>
parent d76541a6
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
#include <algorithm> #include <algorithm>
#include <cassert> #include <cassert>
#include <cfloat> // FLT_MIN #include <cfloat>
#ifdef USE_ROCM #ifdef USE_ROCM
#include <hip/hip_bf16.h> #include <hip/hip_bf16.h>
...@@ -479,6 +479,7 @@ __global__ void concat_and_cache_ds_mla_kernel( ...@@ -479,6 +479,7 @@ __global__ void concat_and_cache_ds_mla_kernel(
// Compute the scale for the tile // Compute the scale for the tile
float tile_scale = max_abs / 448.f; float tile_scale = max_abs / 448.f;
tile_scale = fmaxf(tile_scale, FLT_MIN);
// The first lane of each half-warp writes the scale to kv_cache // The first lane of each half-warp writes the scale to kv_cache
if ((lane_idx == 0) || (lane_idx == 16)) { if ((lane_idx == 0) || (lane_idx == 16)) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment