Unverified commit 2357480b, authored by rivos-shreeasish, committed by GitHub
Browse files

[BugFix] Fix UB in per_token_group_quant.cu (#24913)


Signed-off-by: Shreeasish Kumar <shreeasish@rivosinc.com>
parent f11e3c51
......@@ -12,8 +12,8 @@
#include "../vectorization_utils.cuh"
#include "../../dispatch_utils.h"
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
unsigned mask = 0xffff;
__device__ __forceinline__ float GroupReduceMax(float val) {
unsigned mask = threadIdx.x % 32 >= 16 ? 0xffff0000 : 0x0000ffff;
val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
......@@ -86,7 +86,7 @@ __global__ void per_token_group_quant_8bit_kernel(
threads_per_group, // stride in group
scalar_op_cache); // scalar handler
local_absmax = GroupReduceMax(local_absmax, lane_id);
local_absmax = GroupReduceMax(local_absmax);
float y_s = local_absmax / max_8bit;
if constexpr (SCALE_UE8M0) {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment