Commit 8a03ff34 authored by wenjh
Browse files

Fix vector blockwise acc problem


Signed-off-by: wenjh <wenjh@sugon.com>
parent d1bf39cf
......@@ -171,7 +171,11 @@ class BlockwiseQuantizerReference:
qx = x_tiled * scale.reshape(M, K // tile_len, 1)
qx = torch.clamp(qx, min=-dtype_max, max=dtype_max)
if quant_dtype == torch.int8:
qx = torch.round(qx)
positive_mask = qx >= 0
negative_mask = ~positive_mask
pos_part = torch.where(positive_mask, torch.floor(qx + 0.5), 0)
neg_part = torch.where(negative_mask, torch.ceil(qx - 0.5), 0)
qx = pos_part + neg_part
qx = qx.to(dtype=quant_dtype)
qx = qx.reshape(M, K)
return qx, scale_inv
......
......@@ -4,7 +4,7 @@
from typing import Tuple
import torch
from torch.utils.cpp_extension import IS_HIP_EXTENSION
def scale_from_amax_tensor(
x_dtype: torch.dtype,
......@@ -48,6 +48,10 @@ def scale_from_amax_tensor(
# No subnormals and zero.
assert (exp > -127).all()
unity = torch.tensor([1.0], device=exp.device)
if IS_HIP_EXTENSION:
host_scale = torch.ldexp(unity.cpu(), exp.cpu())
scale = host_scale.to(exp.device)
else:
torch.ldexp(unity, exp, out=scale)
# Case where amax is inf. The frexp, ldexp logic changes 0.0 scales
# Return 0.0 for 0.0 scale for consistency with non-pow2 scale
......
......@@ -187,8 +187,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
// Step 3: Store cast output
CType scale_data = block_tile_scale;
OType scaled_elt =
OType scaled_elt = 0;
if constexpr(std::is_same_v<OType, int8_t>) {
scaled_elt =
static_cast<OType>(lroundf(fmaxf(-127.0f, fminf(127.0f, static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data))));
}
else {
scaled_elt =
static_cast<OType>(static_cast<CType>(thrd_tile_input[i].data.elt[j]) * scale_data);
}
tmp_output_c.data.elt[j] = scaled_elt;
// Step 4: do transpose within thread tile
if constexpr (kReturnTranspose) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment