Unverified Commit 6ea001cf authored by Vensen's avatar Vensen Committed by GitHub
Browse files

[Bugfix][Quantization] Ensure input contiguity in per_token_quant_int8 (#31637)


Signed-off-by: default avatarvensen <vensenmu@gmail.com>
parent 1c46dea0
......@@ -122,15 +122,17 @@ def _per_token_quant_int8(
def per_token_quant_int8(x):
original_shape = x.shape
if x.dim() > 2:
x = x.view(-1, original_shape[-1])
M = x.numel() // x.shape[-1]
N = x.shape[-1]
x_q = torch.empty_like(x, device=x.device, dtype=torch.int8)
scales = torch.empty(x.shape[:-1] + (1,), device=x.device, dtype=torch.float32)
x_q = torch.empty((M, N), device=x.device, dtype=torch.int8)
scales = torch.empty((M, 1), device=x.device, dtype=torch.float32)
BLOCK = triton.next_power_of_2(N)
# heuristics for number of warps
num_warps = min(max(BLOCK // 256, 1), 8)
assert x.is_contiguous()
x = x.contiguous()
_per_token_quant_int8[(M,)](
x,
x_q,
......@@ -142,7 +144,8 @@ def per_token_quant_int8(x):
num_warps=num_warps,
num_stages=1,
)
x_q = x_q.view(*original_shape)
scales = scales.view(*original_shape[:-1], 1)
return x_q, scales
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment