Unverified Commit fc0ff8ce authored by Dalong, committed by GitHub

fix checkpoint load error and stop updating parameters in evaluation stage (#3124)

parent 62d5812d
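The core of the checkpoint fix is that `register_buffer` with a real tensor makes the quantization state part of `state_dict()`, while a buffer registered as `None` is skipped during serialization, which is one way the reported load error arises. A minimal sketch, independent of the NNI code, of the behavior the commit relies on:

```python
import torch
import torch.nn as nn

class Wrapped(nn.Module):
    def __init__(self):
        super().__init__()
        # real tensors are serialized; a None buffer is skipped entirely
        self.register_buffer("scale", torch.Tensor([1.0]))
        self.register_buffer("zero_point", torch.Tensor([0.0]))

m = Wrapped()
print(sorted(m.state_dict().keys()))   # ['scale', 'zero_point']
m.load_state_dict(m.state_dict())      # round-trips cleanly
```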
@@ -73,9 +73,9 @@ def update_quantization_param(bits, rmin, rmax):
     ----------
     bits : int
         quantization bits length
-    rmin : float
+    rmin : Tensor
         min value of real value
-    rmax : float
+    rmax : Tensor
         max value of real value
 
     Returns
@@ -85,12 +85,17 @@ def update_quantization_param(bits, rmin, rmax):
     # extend the [min, max] interval to ensure that it contains 0.
     # Otherwise, we would not meet the requirement that 0 be an exactly
     # representable value.
-    rmin = min(rmin, 0)
-    rmax = max(rmax, 0)
+    if rmin.is_cuda:
+        rmin = torch.min(rmin, torch.Tensor([0]).cuda())
+        rmax = torch.max(rmax, torch.Tensor([0]).cuda())
+        qmin = torch.Tensor([0]).cuda()
+        qmax = torch.Tensor([(1 << bits) - 1]).cuda()
+    else:
+        rmin = torch.min(rmin, torch.Tensor([0]))
+        rmax = torch.max(rmax, torch.Tensor([0]))
+        qmin = torch.Tensor([0])
+        qmax = torch.Tensor([(1 << bits) - 1])
 
-    # the min and max quantized values, as floating-point values
-    qmin = 0
-    qmax = (1 << bits) - 1
-
     # First determine the scale.
     scale = (rmax - rmin) / (qmax - qmin)
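For reference, a self-contained sketch of what `update_quantization_param` computes after this change. The zero-point derivation below the visible hunk is an assumption here, following the standard asymmetric scheme: nudge the interval to contain 0, then map `[rmin, rmax]` linearly onto `[0, 2**bits - 1]`.

```python
import torch

def quant_params_sketch(bits, rmin, rmax):
    zero = torch.zeros_like(rmin)
    rmin = torch.min(rmin, zero)            # interval must contain 0
    rmax = torch.max(rmax, zero)
    qmin, qmax = 0.0, float((1 << bits) - 1)
    scale = (rmax - rmin) / (qmax - qmin)   # real units per quantized step
    zero_point = torch.round(qmin - rmin / scale)  # q value that represents real 0
    return scale, zero_point

scale, zp = quant_params_sketch(8, torch.tensor([-0.5]), torch.tensor([1.5]))
```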
@@ -143,11 +148,11 @@ class QAT_Quantizer(Quantizer):
             types of nn.module you want to apply quantization, eg. 'Conv2d'
         """
         super().__init__(model, config_list, optimizer)
-        self.steps = 1
         modules_to_compress = self.get_modules_to_compress()
+        self.bound_model.register_buffer("steps", torch.Tensor([1]))
         for layer, config in modules_to_compress:
-            layer.module.register_buffer("zero_point", None)
-            layer.module.register_buffer("scale", None)
+            layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
+            layer.module.register_buffer("scale", torch.Tensor([1.0]))
             if "output" in config.get("quant_types", []):
                 layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
                 layer.module.register_buffer('tracked_min_biased', torch.zeros(1))
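Note the step counter also moves from a plain Python attribute on the quantizer (`self.steps`) to a buffer on the bound model, so it is saved and restored together with the weights. A small sketch of the difference this makes:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)
model.register_buffer("steps", torch.Tensor([1]))
ckpt = model.state_dict()          # includes 'steps'
model.steps += 1                   # training advances the counter
model.load_state_dict(ckpt)        # restoring the checkpoint restores steps too
print(model.steps)                 # tensor([1.])
```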
@@ -187,13 +192,17 @@ class QAT_Quantizer(Quantizer):
             quantization bits length
         op : torch.nn.Module
             target module
-        real_val : float
+        real_val : Tensor
             real value to be quantized
 
         Returns
         -------
-        float
+        Tensor
         """
+        if real_val.is_cuda:
+            op.zero_point = op.zero_point.cuda()
+            op.scale = op.scale.cuda()
         transformed_val = op.zero_point + real_val / op.scale
         qmin = 0
         qmax = (1 << bits) - 1
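The `.cuda()` guard exists because the buffers are created on CPU at wrap time; moving them lazily keeps `op.zero_point + real_val / op.scale` on one device. The clamp/round tail of `_quantize` is cut off in this hunk, so the pair below is a sketch of the standard affine scheme the visible lines set up, not a copy of the method:

```python
import torch

def quantize(real_val, scale, zero_point, bits):
    qmin, qmax = 0, (1 << bits) - 1
    q = zero_point + real_val / scale       # affine transform into quantized units
    return torch.clamp(torch.round(q), qmin, qmax)

def dequantize(q, scale, zero_point):
    return (q - zero_point) * scale         # back to (approximate) real values

x = torch.tensor([-0.4, 0.0, 0.9])
q = quantize(x, scale=torch.tensor([0.01]), zero_point=torch.tensor([40.0]), bits=8)
print(dequantize(q, torch.tensor([0.01]), torch.tensor([40.0])))
```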
@@ -229,7 +238,8 @@ class QAT_Quantizer(Quantizer):
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"
-        if quant_start_step > self.steps:
+        # we don't update weight in evaluation stage
+        if quant_start_step > self.bound_model.steps or not wrapper.training:
             return weight
 
         # if bias exists, quantize bias to uint32
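The added `not wrapper.training` guard means weight statistics stop changing once the wrapper is switched to eval mode; `wrapper.training` is the standard `nn.Module` flag toggled by `train()`/`eval()`. A toy illustration of the pattern (hypothetical module, not NNI code):

```python
import torch
import torch.nn as nn

class StatTracker(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("rmax", torch.zeros(1))

    def forward(self, w):
        if self.training:                   # mirrors the `wrapper.training` check
            self.rmax = torch.max(self.rmax, w.detach().max().reshape(1))
        return w                            # eval mode: statistics stay frozen

t = StatTracker()
t.train(); t(torch.randn(8))    # rmax may move
t.eval();  t(torch.randn(8))    # rmax untouched
```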
@@ -258,15 +268,17 @@ class QAT_Quantizer(Quantizer):
         quant_start_step = config.get('quant_start_step', 0)
         assert output_bits >= 1, "quant bits length should be at least 1"
-        if quant_start_step > self.steps:
+        if quant_start_step > self.bound_model.steps:
             return output
 
-        current_min, current_max = torch.min(output), torch.max(output)
-        module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
-                                                                   module.ema_decay, self.steps)
-        module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
-                                                                   module.ema_decay, self.steps)
-        module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
+        # we don't update output quantization parameters in evaluation stage
+        if wrapper.training:
+            current_min, current_max = torch.min(output), torch.max(output)
+            module.tracked_min_biased, module.tracked_min = update_ema(module.tracked_min_biased, current_min,
+                                                                       module.ema_decay, self.bound_model.steps)
+            module.tracked_max_biased, module.tracked_max = update_ema(module.tracked_max_biased, current_max,
+                                                                       module.ema_decay, self.bound_model.steps)
+            module.scale, module.zero_point = update_quantization_param(output_bits, module.tracked_min, module.tracked_max)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
         return out
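`update_ema` is defined outside this hunk; under the usual assumption that it keeps a biased running average and bias-corrects it by step count (as the `_biased` buffer names suggest), it amounts to:

```python
def update_ema_sketch(biased_ema, value, decay, step):
    # biased running average, then Adam-style startup correction
    biased_ema = decay * biased_ema + (1 - decay) * value
    corrected = biased_ema / (1 - decay ** step)
    return biased_ema, corrected
```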
@@ -279,7 +291,7 @@ class QAT_Quantizer(Quantizer):
         """
         override `compressor` `step` method, quantization only happens after certain number of steps
         """
-        self.steps += 1
+        self.bound_model.steps += 1
 
 
 class DoReFaQuantizer(Quantizer):
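For context, a hedged usage sketch: `steps` should advance once per training step so that `quant_start_step` delays fake quantization by that many updates; the assumption here is that the compressor's `step` override runs alongside `optimizer.step()`, and `model`, `config_list`, `loader`, `criterion`, and `optimizer` are placeholders.

```python
quantizer = QAT_Quantizer(model, config_list, optimizer)
quantizer.compress()
for data, target in loader:
    optimizer.zero_grad()
    loss = criterion(model(data), target)
    loss.backward()
    optimizer.step()   # each step also bumps bound_model.steps toward quant_start_step
```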