Commit 8eff19c9 authored by wenjh's avatar wenjh
Browse files

Fix verify acc failed of blockwise quantizer


Signed-off-by: wenjh's avatarwenjh <wenjh@sugon.com>
parent 8a03ff34
...@@ -130,6 +130,12 @@ class BlockwiseQuantizerReference: ...@@ -130,6 +130,12 @@ class BlockwiseQuantizerReference:
) )
qx = x_tiled * scale.reshape(M // tile_len, 1, K // tile_len, 1) qx = x_tiled * scale.reshape(M // tile_len, 1, K // tile_len, 1)
qx = torch.clamp(qx, min=-dtype_max, max=dtype_max) qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
if quant_dtype == torch.int8:
positive_mask = qx >= 0
negative_mask = ~positive_mask
pos_part = torch.where(positive_mask, torch.floor(qx + 0.5), 0)
neg_part = torch.where(negative_mask, torch.ceil(qx - 0.5), 0)
qx = pos_part + neg_part
qx = qx.to(dtype=quant_dtype) qx = qx.to(dtype=quant_dtype)
qx = qx.reshape(M, K) qx = qx.reshape(M, K)
if unpadded_k != K or unpadded_m != M: if unpadded_k != K or unpadded_m != M:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment