Unverified commit fc0ff8ce authored by Dalong, committed by GitHub

fix checkpoint load error and stop updating parameters in evaluation stage (#3124)

parent 62d5812d
@@ -73,9 +73,9 @@ def update_quantization_param(bits, rmin, rmax):
     ----------
     bits : int
         quantization bits length
-    rmin : float
+    rmin : Tensor
         min value of real value
-    rmax : float
+    rmax : Tensor
         max value of real value

     Returns
@@ -85,12 +85,17 @@ def update_quantization_param(bits, rmin, rmax):
     # extend the [min, max] interval to ensure that it contains 0.
     # Otherwise, we would not meet the requirement that 0 be an exactly
     # representable value.
-    rmin = min(rmin, 0)
-    rmax = max(rmax, 0)
-
-    # the min and max quantized values, as floating-point values
-    qmin = 0
-    qmax = (1 << bits) - 1
+    if rmin.is_cuda:
+        rmin = torch.min(rmin, torch.Tensor([0]).cuda())
+        rmax = torch.max(rmax, torch.Tensor([0]).cuda())
+        qmin = torch.Tensor([0]).cuda()
+        qmax = torch.Tensor([(1 << bits) - 1]).cuda()
+    else:
+        rmin = torch.min(rmin, torch.Tensor([0]))
+        rmax = torch.max(rmax, torch.Tensor([0]))
+        qmin = torch.Tensor([0])
+        qmax = torch.Tensor([(1 << bits) - 1])

     # First determine the scale.
     scale = (rmax - rmin) / (qmax - qmin)
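For context, this hunk only makes [rmin, rmax] zero-inclusive and builds qmin/qmax on the right device before the scale is computed. A minimal device-agnostic sketch of the same math (the helper name affine_scale is ours, not from the commit):

import torch

def affine_scale(bits, rmin, rmax):
    # Extend the interval so that real 0.0 maps exactly onto an integer level;
    # zero padding and ReLU outputs would otherwise quantize to a nonzero value.
    zero = torch.zeros(1, device=rmin.device)
    rmin = torch.min(rmin, zero)
    rmax = torch.max(rmax, zero)
    # the min and max quantized values, as floating-point values
    qmin, qmax = 0.0, float((1 << bits) - 1)
    return (rmax - rmin) / (qmax - qmin)

print(affine_scale(8, torch.tensor([0.2]), torch.tensor([6.0])))  # ~= 6/255

Allocating with device=rmin.device would fold the commit's two is_cuda branches into one path with the same behavior.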
@@ -143,11 +148,11 @@ class QAT_Quantizer(Quantizer):
             types of nn.module you want to apply quantization, eg. 'Conv2d'
         """
         super().__init__(model, config_list, optimizer)
-        self.steps = 1
         modules_to_compress = self.get_modules_to_compress()
+        self.bound_model.register_buffer("steps", torch.Tensor([1]))
         for layer, config in modules_to_compress:
-            layer.module.register_buffer("zero_point", None)
-            layer.module.register_buffer("scale", None)
+            layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
+            layer.module.register_buffer("scale", torch.Tensor([1.0]))
             if "output" in config.get("quant_types", []):
                 layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
                 layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
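These buffer changes are the checkpoint fix: a buffer registered as None is omitted from state_dict(), so a checkpoint saved after real tensors had been assigned could not be loaded back into a freshly constructed (still-None) model. A minimal repro under current PyTorch semantics (OldStyle and NewStyle are hypothetical stand-ins):

import torch
import torch.nn as nn

class OldStyle(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", None)  # None buffers are skipped by state_dict()

class NewStyle(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("scale", torch.Tensor([1.0]))  # always serialized

m = OldStyle()
m.scale = torch.Tensor([0.5])   # filled in during training, as the quantizer does
sd = m.state_dict()             # "scale" is now present in the checkpoint
NewStyle().load_state_dict(sd)  # loads cleanly
OldStyle().load_state_dict(sd)  # RuntimeError: unexpected key(s) "scale"

Registering steps on bound_model, instead of keeping self.steps as a plain attribute, serves the same purpose: the counter is saved and restored together with the model.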
@@ -187,13 +192,17 @@ class QAT_Quantizer(Quantizer):
             quantization bits length
         op : torch.nn.Module
             target module
-        real_val : float
+        real_val : Tensor
             real value to be quantized

         Returns
         -------
-        float
+        Tensor
         """
+        if real_val.is_cuda:
+            op.zero_point = op.zero_point.cuda()
+            op.scale = op.scale.cuda()
+
         transformed_val = op.zero_point + real_val / op.scale
         qmin = 0
         qmax = (1 << bits) - 1
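The two added lines move op.zero_point and op.scale to CUDA so the affine transform below runs on one device. Completed from the surrounding lines, that transform is q = clamp(round(zero_point + real_val / scale), qmin, qmax), and dequantization inverts it up to rounding error. A self-contained sketch (helper names are ours):

import torch

def quantize(real_val, scale, zero_point, bits=8):
    qmin, qmax = 0, (1 << bits) - 1
    transformed_val = zero_point + real_val / scale  # affine map to the integer grid
    return torch.clamp(torch.round(transformed_val), qmin, qmax)

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale  # inverse map; rounding error remains

x = torch.tensor([0.0, 1.7, 5.9])
scale, zero_point = torch.tensor([6.0 / 255]), torch.tensor([0.0])
x_hat = dequantize(quantize(x, scale, zero_point), scale, zero_point)
# x_hat is close to x; the residual is the quantization noise QAT trains through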
@@ -229,7 +238,8 @@ class QAT_Quantizer(Quantizer):
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"

-        if quant_start_step > self.steps:
+        # we dont update weight in evaluation stage
+        if quant_start_step > self.bound_model.steps or not wrapper.training:
             return weight

         # if bias exists, quantize bias to uint32
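The not wrapper.training guard relies on standard nn.Module semantics: the quantizer's wrappers are submodules of the bound model, so model.train() and model.eval() flip their training flag along with every other submodule. A quick illustration with a stand-in model:

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4))  # stand-in for the wrapped model
model.eval()
assert all(not m.training for m in model.modules())  # quantize_weight returns early
model.train()
assert all(m.training for m in model.modules())      # parameter updates resume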
@@ -258,15 +268,17 @@ class QAT_Quantizer(Quantizer):
         quant_start_step = config.get('quant_start_step', 0)
         assert output_bits >= 1, "quant bits length should be at least 1"

-        if quant_start_step > self.steps:
+        if quant_start_step > self.bound_model.steps:
             return output

-        current_min, current_max = torch.min(output), torch.max(output)
-        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
-                                                                   module.ema_decay, self.steps)
-        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
-                                                                   module.ema_decay, self.steps)
-        module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
+        # we dont update output quantization parameters in evaluation stage
+        if wrapper.training:
+            current_min, current_max = torch.min(output), torch.max(output)
+            module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
+                                                                       module.ema_decay, self.bound_model.steps)
+            module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
+                                                                       module.ema_decay, self.bound_model.steps)
+            module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
         return out
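update_ema itself is outside this diff. A common bias-corrected form consistent with the tracked_*_biased / tracked_* naming, assumed here rather than copied from the repo, keeps a biased running value and divides out the startup bias using the step count:

import torch

def update_ema(biased_ema, value, decay, step):
    # Assumed sketch of an Adam-style bias-corrected EMA; the actual NNI
    # helper may differ in detail.
    biased_ema = biased_ema * decay + (1 - decay) * value
    unbiased_ema = biased_ema / (1 - decay ** step)  # remove early-step bias
    return biased_ema, unbiased_ema

Passing self.bound_model.steps as the step keeps the correction consistent across a checkpoint save/resume, which a plain self.steps attribute would not survive.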
@@ -279,7 +291,7 @@ class QAT_Quantizer(Quantizer):
         """
         override `compressor` `step` method, quantization only happens after certain number of steps
         """
-        self.steps += 1
+        self.bound_model.steps += 1


 class DoReFaQuantizer(Quantizer):