Unverified commit 0c13ea49 authored by lin bin, committed by GitHub

fix QAT ema issue and tensor type error (#3219)

parent cc58a81d
@@ -41,7 +41,7 @@ class NaiveQuantizer(Quantizer):
         wrapper.module.weight = weight
         return weight
 
 
-def update_ema(biased_ema, value, decay, step):
+def update_ema(biased_ema, value, decay):
     """
     calculate biased stat and unbiased stat in each step using exponential moving average method
@@ -53,16 +53,13 @@ def update_ema(biased_ema, value, decay, step):
         current stat value
     decay : float
         the weight of previous stat value, larger means smoother curve
-    step : int
-        current step
 
     Returns
     -------
     float, float
     """
     biased_ema = biased_ema * decay + (1 - decay) * value
-    unbiased_ema = biased_ema / (1 - decay ** step)  # Bias correction
-    return biased_ema, unbiased_ema
+    return biased_ema
 
 
 def update_quantization_param(bits, rmin, rmax):
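For context, a minimal runnable sketch of the exponential moving average as it stands after this change; the `decay` value and the toy loop are illustrative, not part of this commit:

```python
import torch

def update_ema(biased_ema, value, decay):
    # plain EMA with no bias correction: early estimates stay biased
    # toward the initial value until enough updates accumulate
    return biased_ema * decay + (1 - decay) * value

# toy usage: track the running max of per-batch activations
tracked_max = torch.tensor(0.0)
for batch_max in [1.0, 3.0, 2.5]:
    tracked_max = update_ema(tracked_max, torch.tensor(batch_max), decay=0.99)
print(tracked_max)  # drifts slowly toward the observed maxima
```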
@@ -260,16 +257,17 @@ class QAT_Quantizer(Quantizer):
         assert output_bits >= 1, "quant bits length should be at least 1"
 
         if quant_start_step > self.bound_model.steps:
             module.tracked_min_biased, module.tracked_max_biased = torch.min(output), torch.max(output)
             return output
 
         # we don't update output quantization parameters in evaluation stage
         if wrapper.training:
             current_min, current_max = torch.min(output), torch.max(output)
-            module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
-                                                                       module.ema_decay, self.bound_model.steps)
-            module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
-                                                                       module.ema_decay, self.bound_model.steps)
-            module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
+            module.tracked_min_biased = update_ema(module.tracked_min_biased, current_min,
+                                                   module.ema_decay)
+            module.tracked_max_biased = update_ema(module.tracked_max_biased, current_max,
+                                                   module.ema_decay)
+            module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_biased, module.tracked_max_biased)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
         return out
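`update_quantization_param` itself is not shown in this diff. As context, here is a sketch of the affine quantization scheme such a function commonly implements; the formulas below are an assumption for illustration, not the NNI source:

```python
import torch

def update_quantization_param(bits, rmin, rmax):
    # hypothetical affine-quantization sketch, not the NNI implementation;
    # extend the tracked range to include zero so zero-valued activations
    # map to an exact quantized value
    rmin = torch.min(rmin, torch.zeros_like(rmin))
    rmax = torch.max(rmax, torch.zeros_like(rmax))
    qmin, qmax = 0, (1 << bits) - 1
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = torch.clamp(qmin - torch.round(rmin / scale), qmin, qmax)
    return scale, zero_point

# e.g. an 8-bit range tracked by the EMA above
scale, zero_point = update_quantization_param(8, torch.tensor(-1.0), torch.tensor(3.0))
```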
@@ -669,7 +669,7 @@ class QuantGrad(torch.autograd.Function):
             raise ValueError("unrecognized QuantType.")
 
         bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
-        qmin, qmax = torch.Tensor([0], device=tensor.device), torch.Tensor([(1 << bits) - 1], device=tensor.device)
+        qmin, qmax = torch.Tensor([0]).to(device=tensor.device), torch.Tensor([(1 << bits) - 1]).to(device=tensor.device)
         ctx.save_for_backward(tensor, wrapper.module.scale, wrapper.module.zero_point, qmin, qmax)
         return output
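The "tensor type error" fix in this hunk works around the legacy `torch.Tensor` constructor, which raises a runtime error when handed a non-CPU `device` argument; constructing on CPU and moving with `.to()` sidesteps it. A standalone sketch (the `bits` value and device selection below are illustrative):

```python
import torch

bits = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# torch.Tensor([0], device=device) errors on CUDA devices, so construct
# on CPU first and move the tensor afterwards, as this commit does:
qmin = torch.Tensor([0]).to(device=device)
qmax = torch.Tensor([(1 << bits) - 1]).to(device=device)

# the lowercase factory function accepts device directly and avoids
# the legacy-constructor pitfall altogether:
qmin_alt = torch.tensor([0.0], device=device)
```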