Unverified Commit a3319f4f authored by Michael Goin's avatar Michael Goin Committed by GitHub
Browse files

[Bugfix] Enforce contiguous input for dynamic_per_token FP8/INT8 quant (#19452)


Signed-off-by: default avatarmgoin <mgoin64@gmail.com>
parent 9d880f59
...@@ -1270,7 +1270,7 @@ def scaled_fp8_quant( ...@@ -1270,7 +1270,7 @@ def scaled_fp8_quant(
device=input.device, device=input.device,
dtype=torch.float32) dtype=torch.float32)
torch.ops._C.dynamic_per_token_scaled_fp8_quant( torch.ops._C.dynamic_per_token_scaled_fp8_quant(
output, input, scale, scale_ub) output, input.contiguous(), scale, scale_ub)
else: else:
scale = torch.zeros(1, device=input.device, dtype=torch.float32) scale = torch.zeros(1, device=input.device, dtype=torch.float32)
torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale) torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
...@@ -1379,8 +1379,8 @@ def scaled_int8_quant( ...@@ -1379,8 +1379,8 @@ def scaled_int8_quant(
dtype=torch.float32) dtype=torch.float32)
input_azp = None if symmetric else torch.empty_like(input_scales, input_azp = None if symmetric else torch.empty_like(input_scales,
dtype=torch.int32) dtype=torch.int32)
torch.ops._C.dynamic_scaled_int8_quant(output, input, input_scales, torch.ops._C.dynamic_scaled_int8_quant(output, input.contiguous(),
input_azp) input_scales, input_azp)
return output, input_scales, input_azp return output, input_scales, input_azp
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment