"...git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "29deb085f097f584223e0e276050b867577693d7"
Unverified Commit 396ae65c, authored by chenbohua3, committed by GitHub

support dp multi-gpu training for QAT quantizer (#4127)

parent 29b4d46c
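Context for the diff below: the QAT quantizer previously kept its quantization state (scale, zero point, tracked min/max ranges, step counter) in plain attributes updated by Python-level reassignment, which torch.nn.DataParallel silently discards because each forward pass runs on throwaway per-GPU replicas. This commit moves that state into registered buffers that are updated in place. A minimal sketch of the workflow this enables, assuming NNI's documented QAT_Quantizer API (the model and config here are illustrative, not taken from the commit):

    import torch
    from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

    model = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.ReLU())
    model = torch.nn.DataParallel(model.cuda())   # assumes at least one CUDA device
    config_list = [{
        'quant_types': ['weight', 'input', 'output'],
        'quant_bits': {'weight': 8, 'input': 8, 'output': 8},
        'op_types': ['Conv2d'],
    }]
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    quantizer = QAT_Quantizer(model, config_list, optimizer)
    quantizer.compress()   # instruments the underlying module; training proceeds as usual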
@@ -373,21 +373,22 @@ class QAT_Quantizer(Quantizer):
         self.quant_grad = QATGrad.apply
         modules_to_compress = self.get_modules_to_compress()
         device = next(model.parameters()).device
-        self.bound_model.register_buffer("steps", torch.Tensor([1]))
+        self.bound_model.register_buffer("steps", torch.tensor(1))
         for layer, config in modules_to_compress:
-            layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
-            layer.module.register_buffer("scale", torch.Tensor([1.0]))
-            layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
+            module = layer.module
+            module.register_buffer("zero_point", torch.tensor([0.0]))
+            module.register_buffer("scale", torch.tensor([1.0]))
+            module.register_buffer('ema_decay', torch.tensor([0.99]))
             if "weight" in config.get("quant_types", []):
-                layer.module.register_buffer('weight_bits', torch.zeros(1))
+                module.register_buffer('weight_bits', torch.zeros(1))
             if "input" in config.get("quant_types", []):
-                layer.module.register_buffer('tracked_min_input', torch.zeros(1))
-                layer.module.register_buffer('tracked_max_input', torch.zeros(1))
-                layer.module.register_buffer('input_bits', torch.zeros(1))
+                module.register_buffer('input_bits', torch.zeros(1))
+                module.register_buffer('tracked_min_input', torch.zeros(1))
+                module.register_buffer('tracked_max_input', torch.zeros(1))
             if "output" in config.get("quant_types", []):
-                layer.module.register_buffer('output_bits', torch.zeros(1))
-                layer.module.register_buffer('tracked_min_output', torch.zeros(1))
-                layer.module.register_buffer('tracked_max_output', torch.zeros(1))
+                module.register_buffer('output_bits', torch.zeros(1))
+                module.register_buffer('tracked_min_output', torch.zeros(1))
+                module.register_buffer('tracked_max_output', torch.zeros(1))
         self.bound_model.to(device)

     def _del_simulated_attr(self, module):
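Why everything becomes a registered buffer: DataParallel re-broadcasts module buffers to every replica at the start of each forward pass, and on the source GPU the replica's buffer aliases the original storage, so in-place copy_() updates persist across iterations, while plain attribute reassignment only rebinds a name on a replica that is discarded. A standalone sketch of the two behaviors (not NNI code):

    import torch

    class Tracker(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer('tracked_max', torch.zeros(1))

        def forward(self, x):
            self.tracked_max.copy_(torch.max(x))  # in place: survives replication
            # self.tracked_max = torch.max(x)     # rebinds on the replica: lost
            return x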
@@ -479,8 +480,7 @@ class QAT_Quantizer(Quantizer):
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"
         # we don't update weight in evaluation stage
-        if quant_start_step > self.bound_model.steps:
+        if quant_start_step > int(self.bound_model.steps):
             return weight
         if not wrapper.training:
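On the int() cast: `steps` is now a 0-dim tensor, possibly on GPU, so the bare comparison would produce a tensor; converting once to a Python int keeps the step gate an ordinary host-side check. An illustrative reading of the change:

    import torch

    steps = torch.tensor(1)                 # the registered 0-dim buffer
    quant_start_step = 100
    skip = quant_start_step > int(steps)    # plain Python bool, one explicit sync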
@@ -488,10 +488,16 @@ class QAT_Quantizer(Quantizer):
         # quantize weight
         rmin, rmax = torch.min(weight), torch.max(weight)
-        module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
+        scale, zero_point = update_quantization_param(weight_bits, rmin, rmax)
+        module.scale.copy_(scale)
+        module.zero_point.copy_(zero_point)
         weight = self._quantize(weight_bits, module, weight)
         weight = self._dequantize(module, weight)
         module.weight_bits = torch.Tensor([weight_bits])
+        # The weight cannot be modified in place, so under torch.nn.DataParallel this
+        # assignment is lost after each forward pass. It still takes effect on every
+        # replicated module during the forward pass itself, which is enough for the
+        # quantized weight to be used correctly.
+        wrapper.module.weight = weight
         return weight
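For readers new to QAT, the `_quantize`/`_dequantize` pair above simulates integer quantization in float ("fake quantization"), and `wrapper.module.weight = weight` works precisely because it runs inside each replica's forward pass. A plausible sketch of that round trip under the usual asymmetric unsigned scheme; NNI's actual helpers may differ in detail:

    import torch

    def fake_quantize(w, scale, zero_point, bits=8):
        qmin, qmax = 0, (1 << bits) - 1
        q = torch.clamp(torch.round(w / scale + zero_point), qmin, qmax)
        return (q - zero_point) * scale   # dequantize back to float

    w = torch.randn(8)
    w_sim = fake_quantize(w, scale=torch.tensor(0.05), zero_point=torch.tensor(128.))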
@@ -499,23 +505,30 @@ class QAT_Quantizer(Quantizer):
         config = wrapper.config
         module = wrapper.module
         input_bits = get_bits_length(config, 'input')
-        module.input_bits = torch.Tensor([input_bits])
+        module.input_bits = torch.tensor([input_bits])
         quant_start_step = config.get('quant_start_step', 0)
         assert input_bits >= 1, "quant bits length should be at least 1"
-        if quant_start_step > self.bound_model.steps:
-            module.tracked_min_input, module.tracked_max_input = torch.min(inputs), torch.max(inputs)
+        if quant_start_step > int(self.bound_model.steps):
+            current_min, current_max = torch.min(inputs), torch.max(inputs)
+            module.tracked_min_input.copy_(current_min)
+            module.tracked_max_input.copy_(current_max)
             return inputs
         # we don't update input quantization parameters in evaluation stage
         if wrapper.training:
             current_min, current_max = torch.min(inputs), torch.max(inputs)
-            module.tracked_min_input = update_ema(module.tracked_min_input, current_min,
-                                                  module.ema_decay)
-            module.tracked_max_input = update_ema(module.tracked_max_input, current_max,
-                                                  module.ema_decay)
-            module.scale, module.zero_point = update_quantization_param(
+            current_min = update_ema(module.tracked_min_input, current_min, module.ema_decay)
+            current_max = update_ema(module.tracked_max_input, current_max, module.ema_decay)
+            module.tracked_min_input.copy_(current_min)
+            module.tracked_max_input.copy_(current_max)
+            scale, zero_point = update_quantization_param(
                 input_bits, module.tracked_min_input, module.tracked_max_input)
+            module.scale.copy_(scale)
+            module.zero_point.copy_(zero_point)
         inp = self._quantize(input_bits, module, inputs)
         inp = self._dequantize(module, inp)
         return inp
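The tracked ranges are exponential moving averages of the observed min/max, as in the QAT paper. `update_ema` itself is outside this diff; reconstructed from how it is called here, it is presumably the standard update:

    def update_ema(biased_ema, value, decay):
        # e.g. update_ema(module.tracked_min_input, current_min, module.ema_decay)
        return biased_ema * decay + (1 - decay) * value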
@@ -528,19 +541,26 @@ class QAT_Quantizer(Quantizer):
         quant_start_step = config.get('quant_start_step', 0)
         assert output_bits >= 1, "quant bits length should be at least 1"
-        if quant_start_step > self.bound_model.steps:
-            module.tracked_min_output, module.tracked_max_output = torch.min(output), torch.max(output)
+        if quant_start_step > int(self.bound_model.steps):
+            current_min, current_max = torch.min(output), torch.max(output)
+            module.tracked_min_output.copy_(current_min)
+            module.tracked_max_output.copy_(current_max)
             return output
         # we don't update output quantization parameters in evaluation stage
         if wrapper.training:
             current_min, current_max = torch.min(output), torch.max(output)
-            module.tracked_min_output = update_ema(module.tracked_min_output, current_min,
-                                                   module.ema_decay)
-            module.tracked_max_output = update_ema(module.tracked_max_output, current_max,
-                                                   module.ema_decay)
-            module.scale, module.zero_point = update_quantization_param(
+            tracked_min_output = update_ema(module.tracked_min_output, current_min,
+                                            module.ema_decay)
+            tracked_max_output = update_ema(module.tracked_max_output, current_max,
+                                            module.ema_decay)
+            module.tracked_min_output.copy_(tracked_min_output)
+            module.tracked_max_output.copy_(tracked_max_output)
+            scale, zero_point = update_quantization_param(
                 output_bits, module.tracked_min_output, module.tracked_max_output)
+            module.scale.copy_(scale)
+            module.zero_point.copy_(zero_point)
         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
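`update_quantization_param` is likewise not shown in this diff. A plausible sketch for an asymmetric unsigned scheme, mapping the tracked float range onto the bits-wide integer grid (assumption: NNI's real implementation also guards the degenerate rmin == rmax == 0 case, omitted here):

    import torch

    def update_quantization_param(bits, rmin, rmax):
        qmin, qmax = 0, (1 << bits) - 1
        # extend the range to include zero so zero padding stays exactly representable
        rmin = torch.min(rmin, torch.zeros_like(rmin))
        rmax = torch.max(rmax, torch.zeros_like(rmax))
        scale = (rmax - rmin) / (qmax - qmin)
        zero_point = torch.clamp(torch.round(qmin - rmin / scale), qmin, qmax)
        return scale, zero_point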
@@ -645,7 +665,7 @@ class QAT_Quantizer(Quantizer):
         """
         override `compressor` `step` method, quantization only happens after certain number of steps
         """
-        self.bound_model.steps += 1
+        self.bound_model.steps.add_(1)

 class DoReFaQuantizer(Quantizer):
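`add_(1)` mutates the registered buffer's storage directly instead of going through attribute assignment, so the counter visibly remains the same tensor object that checkpoints and DataParallel broadcasts refer to. A quick self-contained check of that property:

    import torch

    counter = torch.nn.Module()
    counter.register_buffer('steps', torch.tensor(1))
    before = counter.steps
    counter.steps.add_(1)
    assert counter.steps is before and int(counter.steps) == 2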
@@ -602,6 +602,8 @@ class Quantizer(Compressor):
     """
     def __init__(self, model, config_list, optimizer=None, dummy_input=None):
+        if isinstance(model, torch.nn.DataParallel):
+            model = model.module
         self.identity_wrappers = []
         self.conv_bn_patterns = {}
         self.find_conv_bn_patterns(model, dummy_input)
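The two added lines let users hand a DataParallel-wrapped model straight to any Quantizer: the shell is stripped so wrappers and buffers are installed on the underlying module, which is also what DataParallel replicates on each forward pass. The same idiom shown standalone (note that torch.nn.parallel.DistributedDataParallel exposes the same .module attribute, though this commit only handles DataParallel):

    import torch

    def unwrap(model: torch.nn.Module) -> torch.nn.Module:
        # attach compression bookkeeping to the real module under the shell
        if isinstance(model, torch.nn.DataParallel):
            return model.module
        return model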
@@ -892,12 +894,21 @@ class QuantGrad(torch.autograd.Function):
             zero_point = wrapper.module.zero_point
         else:
             scale, zero_point = None, None
-        ctx.save_for_backward(tensor, torch.Tensor([quant_type]), scale, zero_point, qmin, qmax)
+        ctx.save_for_backward(tensor)
+        # Only tensors whose gradients flow back need to go through save_for_backward;
+        # everything else should be assigned directly to ctx.
+        ctx.scale = scale
+        ctx.zero_point = zero_point
+        ctx.quant_type = quant_type
+        ctx.qmin, ctx.qmax = qmin, qmax
         return output

     @classmethod
     def backward(cls, ctx, grad_output):
-        tensor, quant_type, scale, zero_point, qmin, qmax = ctx.saved_variables
+        tensor = ctx.saved_variables[0]
+        scale, zero_point = ctx.scale, ctx.zero_point
+        qmin, qmax = ctx.qmin, ctx.qmax
+        quant_type = ctx.quant_type
         output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
         return output, None, None, None
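The rewritten save logic follows the general rule for custom autograd functions: save_for_backward is for tensors that participate in differentiation (it adds bookkeeping and in-place-modification checks), while constants and non-differentiable values belong directly on ctx; the old code even had to box quant_type into a tensor just to fit the saved tuple. A generic illustration of the pattern, not NNI's class (modern PyTorch spells the accessor ctx.saved_tensors; saved_variables is a deprecated alias):

    import torch

    class ClampSTE(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, lo, hi):
            ctx.save_for_backward(x)   # the tensor backward differentiates through
            ctx.lo, ctx.hi = lo, hi    # plain floats: assign directly to ctx
            return torch.clamp(x, lo, hi)

        @staticmethod
        def backward(ctx, grad_output):
            x, = ctx.saved_tensors
            inside = (x >= ctx.lo) & (x <= ctx.hi)
            return grad_output * inside, None, None   # one grad per forward input

    y = ClampSTE.apply(torch.randn(4, requires_grad=True), -1.0, 1.0)
    y.sum().backward()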