"docs/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "cae707299565e8fc6e0787171c3497bbc960e249"
Unverified commit 0c13ea49 authored by lin bin, committed by GitHub

fix QAT ema issue and tensor type error (#3219)

parent cc58a81d
@@ -41,7 +41,7 @@ class NaiveQuantizer(Quantizer):
         wrapper.module.weight = weight
         return weight
 
 
-def update_ema(biased_ema, value, decay, step):
+def update_ema(biased_ema, value, decay):
     """
     calculate biased stat and unbiased stat in each step using exponential moving average method
@@ -53,16 +53,13 @@ def update_ema(biased_ema, value, decay, step):
         current stat value
     decay : float
         the weight of previous stat value, larger means smoother curve
-    step : int
-        current step
 
     Returns
     -------
     float, float
     """
     biased_ema = biased_ema * decay + (1 - decay) * value
-    unbiased_ema = biased_ema / (1 - decay ** step)  # Bias correction
-    return biased_ema, unbiased_ema
+    return biased_ema
 
 
 def update_quantization_param(bits, rmin, rmax):
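For reference, after this change update_ema reduces to a plain exponential moving average and returns only the biased running statistic; the step-based bias correction is gone. Below is a minimal standalone sketch of the simplified helper, using plain Python floats instead of tensors and a purely illustrative decay value.

def update_ema(biased_ema, value, decay):
    # exponential moving average: keep `decay` of the old estimate,
    # blend in (1 - decay) of the new observation
    return biased_ema * decay + (1 - decay) * value

# Example: track a running estimate of the per-batch minimum activation.
tracked_min = 0.0
for batch_min in [-1.2, -0.8, -1.5, -1.1]:
    tracked_min = update_ema(tracked_min, batch_min, decay=0.99)
print(tracked_min)  # moves toward the observed minima slowly because decay is large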
@@ -260,16 +257,17 @@ class QAT_Quantizer(Quantizer):
         assert output_bits >= 1, "quant bits length should be at least 1"
 
         if quant_start_step > self.bound_model.steps:
+            module.tracked_min_biased, module.tracked_max_biased = torch.min(output), torch.max(output)
             return output
 
         # we dont update output quantization parameters in evaluation stage
         if wrapper.training:
             current_min, current_max = torch.min(output), torch.max(output)
-            module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
-                                                                       module.ema_decay, self.bound_model.steps)
-            module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
-                                                                       module.ema_decay, self.bound_model.steps)
-            module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
+            module.tracked_min_biased = update_ema(module.tracked_min_biased, current_min,
+                                                   module.ema_decay)
+            module.tracked_max_biased = update_ema(module.tracked_max_biased, current_max,
+                                                   module.ema_decay)
+            module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min_biased, module.tracked_max_biased)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
         return out
......
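To see how the simplified helper feeds the quantization parameters in QAT_Quantizer.quantize_output, here is a self-contained sketch of the training-time flow from the hunk above, with the module buffers flattened into local variables. The body given for update_quantization_param is an assumed, textbook asymmetric-quantization formula used only for illustration (the real implementation is elsewhere in this file and not shown in the diff); the 8-bit width and 0.99 decay are likewise example values.

import torch

def update_ema(biased_ema, value, decay):
    # simplified helper from the hunk above
    return biased_ema * decay + (1 - decay) * value

def update_quantization_param(bits, rmin, rmax):
    # Assumed illustrative body: standard asymmetric quantization parameters.
    rmin = torch.min(rmin, torch.zeros_like(rmin))  # range must include zero
    rmax = torch.max(rmax, torch.zeros_like(rmax))
    qmin, qmax = 0, (1 << bits) - 1
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = qmin - torch.round(rmin / scale)
    return scale, zero_point

# One training step for a single wrapped module's output.
output = torch.randn(4, 8)                # stand-in activation tensor
tracked_min_biased = torch.min(output)    # buffers seeded from the current batch
tracked_max_biased = torch.max(output)
ema_decay = 0.99                          # illustrative decay value

current_min, current_max = torch.min(output), torch.max(output)
tracked_min_biased = update_ema(tracked_min_biased, current_min, ema_decay)
tracked_max_biased = update_ema(tracked_max_biased, current_max, ema_decay)
scale, zero_point = update_quantization_param(8, tracked_min_biased, tracked_max_biased)
print(scale.item(), zero_point.item())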
@@ -669,7 +669,7 @@ class QuantGrad(torch.autograd.Function):
             raise ValueError("unrecognized QuantType.")
 
         bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
-        qmin, qmax = torch.Tensor([0], device=tensor.device), torch.Tensor([(1 << bits) - 1], device=tensor.device)
+        qmin, qmax = torch.Tensor([0]).to(device=tensor.device), torch.Tensor([(1 << bits)-1]).to(device=tensor.device)
         ctx.save_for_backward(tensor, wrapper.module.scale, wrapper.module.zero_point, qmin, qmax)
         return output
......
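The last hunk is the tensor type fix from the commit title: passing device= to the legacy torch.Tensor constructor fails when the saved tensor lives on a CUDA device (that constructor only builds CPU tensors this way), so qmin and qmax are now created first and then moved with .to(device=...). A small sketch of the corrected pattern, assuming an 8-bit setting and using a CPU tensor as a stand-in for the activation saved for backward:

import torch

bits = 8                    # example bit width; the real value comes from the layer config
tensor = torch.randn(4, 8)  # stand-in for the tensor handed to QuantGrad.forward

# Build the quantization range on CPU, then move it to the tensor's device,
# matching the replacement line in the hunk above.
qmin = torch.Tensor([0]).to(device=tensor.device)
qmax = torch.Tensor([(1 << bits) - 1]).to(device=tensor.device)
print(qmin.item(), qmax.item())  # 0.0 255.0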