Commit 4cc47ca6 authored by wenjh's avatar wenjh
Browse files

Merge branch 'develop_v2.4'

parents e704bbc8 8eff19c9
......@@ -130,6 +130,12 @@ class BlockwiseQuantizerReference:
)
qx = x_tiled * scale.reshape(M // tile_len, 1, K // tile_len, 1)
qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
if quant_dtype == torch.int8:
positive_mask = qx >= 0
negative_mask = ~positive_mask
pos_part = torch.where(positive_mask, torch.floor(qx + 0.5), 0)
neg_part = torch.where(negative_mask, torch.ceil(qx - 0.5), 0)
qx = pos_part + neg_part
qx = qx.to(dtype=quant_dtype)
qx = qx.reshape(M, K)
if unpadded_k != K or unpadded_m != M:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment