Unverified Commit 9520f251 authored by Dalong's avatar Dalong Committed by GitHub
Browse files

fix potential bug in quantize device (#3212)


Co-authored-by: default avataryuanxiulong <yuanxiulong@bytedance.com>
parent 5f0ac599
...@@ -85,16 +85,10 @@ def update_quantization_param(bits, rmin, rmax): ...@@ -85,16 +85,10 @@ def update_quantization_param(bits, rmin, rmax):
# extend the [min, max] interval to ensure that it contains 0. # extend the [min, max] interval to ensure that it contains 0.
# Otherwise, we would not meet the requirement that 0 be an exactly # Otherwise, we would not meet the requirement that 0 be an exactly
# representable value. # representable value.
if rmin.is_cuda: rmin = torch.min(rmin, torch.Tensor([0]).to(rmin.device))
rmin = torch.min(rmin, torch.Tensor([0]).cuda()) rmax = torch.max(rmax, torch.Tensor([0]).to(rmin.device))
rmax = torch.max(rmax, torch.Tensor([0]).cuda()) qmin = torch.Tensor([0]).to(rmin.device)
qmin = torch.Tensor([0]).cuda() qmax = torch.Tensor([(1 << bits) - 1]).to(rmin.device)
qmax = torch.Tensor([(1 << bits) - 1]).cuda()
else:
rmin = torch.min(rmin, torch.Tensor([0]))
rmax = torch.max(rmax, torch.Tensor([0]))
qmin = torch.Tensor([0])
qmax = torch.Tensor([(1 << bits) - 1])
# First determine the scale. # First determine the scale.
scale = (rmax - rmin) / (qmax - qmin) scale = (rmax - rmin) / (qmax - qmin)
...@@ -103,7 +97,6 @@ def update_quantization_param(bits, rmin, rmax): ...@@ -103,7 +97,6 @@ def update_quantization_param(bits, rmin, rmax):
initial_zero_point = qmin - rmin / scale initial_zero_point = qmin - rmin / scale
# Now we need to nudge the zero point to be an integer # Now we need to nudge the zero point to be an integer
nudged_zero_point = 0
if initial_zero_point < qmin: if initial_zero_point < qmin:
nudged_zero_point = qmin nudged_zero_point = qmin
elif initial_zero_point > qmax: elif initial_zero_point > qmax:
...@@ -199,10 +192,8 @@ class QAT_Quantizer(Quantizer): ...@@ -199,10 +192,8 @@ class QAT_Quantizer(Quantizer):
------- -------
Tensor Tensor
""" """
if real_val.is_cuda: op.zero_point = op.zero_point.to(real_val.device)
op.zero_point = op.zero_point.cuda() op.scale = op.scale.to(real_val.device)
op.scale = op.scale.cuda()
transformed_val = op.zero_point + real_val / op.scale transformed_val = op.zero_point + real_val / op.scale
qmin = 0 qmin = 0
qmax = (1 << bits) - 1 qmax = (1 << bits) - 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment