Commit 847a461b authored by Yu Cheng, committed by LeiWang1999

[Bugfix] Fix X_amax Correctness Issue in Group Cast FP8 (#345)

- Modified the `group_per_split_token_cast_to_fp8` function to include a conditional check against `batch_sizes`, so the per-token scaling factor is computed only for rows inside a group's valid batch range and is set to zero for padded rows. This change fixes the X_amax correctness issue in the FP8 conversion of grouped per-split tokens.
parent 70546adc
@@ -44,7 +44,8 @@ def group_per_split_token_cast_to_fp8(M, M_max, N, BG, blk_m):
             T.reduce_absmax(y_local, y_amax_local, dim=1)
             for i in T.Parallel(blk_m):
                 y_amax_local[i] = T.max(y_amax_local[i], 1e-4)
-                y_s_local[i] = y_amax_local[i] / fp8_max
+                y_s_local[i] = T.if_then_else(row * blk_m + i < batch_sizes[bg],
+                                              y_amax_local[i] / fp8_max, 0)
             for i, j in T.Parallel(blk_m, group_size):
                 y_q_local[i, j] = T.clamp(y_local[i, j] / y_s_local[i], fp8_min, fp8_max)
             T.copy(y_q_local, y_q_local_fp8)
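For reference, here is a minimal NumPy sketch of the masking semantics above, not the TileLang kernel itself. The helper name `ref_group_cast_to_fp8`, the `group_size` default, and the ±448 float8_e4m3 range are illustrative assumptions; the point is that rows at or beyond `batch_sizes[bg]` keep a zero scale, which is what the `if_then_else` guard enforces.

```python
# Reference sketch (assumptions noted above); quantized values are kept in
# float32 for simplicity rather than an actual fp8 dtype.
import numpy as np

FP8_MAX, FP8_MIN = 448.0, -448.0  # assumed float8_e4m3 limits


def ref_group_cast_to_fp8(x, batch_sizes, group_size=128):
    """x: (BG, M_max, N) padded per-split activations; batch_sizes: (BG,).

    Assumes N is divisible by group_size.
    """
    BG, M_max, N = x.shape
    g = N // group_size
    scales = np.zeros((BG, M_max, g), dtype=np.float32)
    x_q = np.zeros_like(x, dtype=np.float32)
    for bg in range(BG):
        m = int(batch_sizes[bg])
        blocks = x[bg, :m].reshape(m, g, group_size)
        # Per-token, per-column-group absmax, clamped away from zero.
        amax = np.maximum(np.abs(blocks).max(axis=-1), 1e-4)
        scales[bg, :m] = amax / FP8_MAX  # valid rows: amax-based scale
        # Rows >= batch_sizes[bg] keep scale 0, matching the kernel's if_then_else.
        x_q[bg, :m] = np.clip(blocks / scales[bg, :m, :, None],
                              FP8_MIN, FP8_MAX).reshape(m, N)
    return x_q, scales
```

Without the guard, padded rows would export a nonzero scale derived from whatever data happens to sit in the padding, which appears to be the X_amax correctness issue the commit title refers to.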