Unverified Commit f3516c28 authored by Ke Bao's avatar Ke Bao Committed by GitHub
Browse files

Fix quant kernel accuracy issue (#2865)

parent 17de02f9
......@@ -22,7 +22,8 @@ def _per_token_quant_int8(
x = tl.load(x_ptr + row_id * stride_x + cols, mask=mask, other=0.0).to(tl.float32)
absmax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
scale_x = absmax / 127
x_q = tl.extra.cuda.libdevice.round(x / scale_x).to(tl.int8)
x_q = x * (127 / absmax)
x_q = tl.extra.cuda.libdevice.round(x_q).to(tl.int8)
tl.store(xq_ptr + row_id * stride_xq + cols, x_q, mask=mask)
tl.store(scale_ptr + row_id, scale_x)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment