"examples/gemm/test_example_gemm.py" did not exist on "41bc15cbd38f8a9112dcdca842a9ec6ca0513e2e"
Unverified commit 396ae65c, authored by chenbohua3 and committed by GitHub

support dp multi-gpu training for QAT quantizer (#4127)

parent 29b4d46c
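For context, a minimal sketch of the workflow this commit enables: build the quantizer on the bare model, then wrap the compressed model in torch.nn.DataParallel for multi-GPU training. The model, config, and optimizer below are illustrative placeholders, not part of this commit; the import path and config schema follow the NNI 2.x documentation.

```python
# Illustrative sketch only: model and config are placeholders, not from this commit.
import torch
import torch.nn as nn
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

model = nn.Sequential(nn.Conv2d(3, 16, 3), nn.ReLU())
config_list = [{
    'quant_types': ['weight', 'input', 'output'],
    'quant_bits': {'weight': 8, 'input': 8, 'output': 8},
    'op_types': ['Conv2d'],
}]
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

quantizer = QAT_Quantizer(model, config_list, optimizer)
quantizer.compress()

# With this commit, the compressed model can then be replicated for data-parallel training.
if torch.cuda.device_count() > 1:
    model = torch.nn.DataParallel(model.cuda())
```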
@@ -373,21 +373,22 @@ class QAT_Quantizer(Quantizer):
         self.quant_grad = QATGrad.apply
         modules_to_compress = self.get_modules_to_compress()
         device = next(model.parameters()).device
-        self.bound_model.register_buffer("steps", torch.Tensor([1]))
+        self.bound_model.register_buffer("steps", torch.tensor(1))
         for layer, config in modules_to_compress:
-            layer.module.register_buffer("zero_point", torch.Tensor([0.0]))
-            layer.module.register_buffer("scale", torch.Tensor([1.0]))
-            layer.module.register_buffer('ema_decay', torch.Tensor([0.99]))
+            module = layer.module
+            module.register_buffer("zero_point", torch.tensor([0.0]))
+            module.register_buffer("scale", torch.tensor([1.0]))
+            module.register_buffer('ema_decay', torch.tensor([0.99]))
             if "weight" in config.get("quant_types", []):
-                layer.module.register_buffer('weight_bits', torch.zeros(1))
+                module.register_buffer('weight_bits', torch.zeros(1))
             if "input" in config.get("quant_types", []):
-                layer.module.register_buffer('tracked_min_input', torch.zeros(1))
-                layer.module.register_buffer('tracked_max_input', torch.zeros(1))
-                layer.module.register_buffer('input_bits', torch.zeros(1))
+                module.register_buffer('input_bits', torch.zeros(1))
+                module.register_buffer('tracked_min_input', torch.zeros(1))
+                module.register_buffer('tracked_max_input', torch.zeros(1))
             if "output" in config.get("quant_types", []):
-                layer.module.register_buffer('output_bits', torch.zeros(1))
-                layer.module.register_buffer('tracked_min_output', torch.zeros(1))
-                layer.module.register_buffer('tracked_max_output', torch.zeros(1))
+                module.register_buffer('output_bits', torch.zeros(1))
+                module.register_buffer('tracked_min_output', torch.zeros(1))
+                module.register_buffer('tracked_max_output', torch.zeros(1))
         self.bound_model.to(device)

     def _del_simulated_attr(self, module):
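A note on why the statistics are registered as buffers and written through a local `module` alias: buffers are part of the module's state, so they are saved in state_dict(), moved by .to(device), and broadcast to every replica that torch.nn.DataParallel creates, which plain Python attributes are not guaranteed to be. A standalone sketch of that behaviour (plain PyTorch, nothing NNI-specific):

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
# Register quantization statistics the same way the hunk above does.
layer.register_buffer('scale', torch.tensor([1.0]))
layer.register_buffer('zero_point', torch.tensor([0.0]))

print('scale' in layer.state_dict())   # True: buffers are serialized with the module
layer.to(torch.device('cpu'))           # .to() moves buffers together with parameters

# In-place writes keep the registered tensor itself up to date, which the copy_() calls
# in the later hunks rely on.
layer.scale.copy_(torch.tensor([0.5]))
print(layer.state_dict()['scale'])      # tensor([0.5000])
```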
@@ -479,8 +480,7 @@ class QAT_Quantizer(Quantizer):
         quant_start_step = config.get('quant_start_step', 0)
         assert weight_bits >= 1, "quant bits length should be at least 1"

-        # we dont update weight in evaluation stage
-        if quant_start_step > self.bound_model.steps:
+        if quant_start_step > int(self.bound_model.steps):
             return weight

         if not wrapper.training:
@@ -488,10 +488,16 @@ class QAT_Quantizer(Quantizer):
         # quantize weight
         rmin, rmax = torch.min(weight), torch.max(weight)
-        module.scale, module.zero_point = update_quantization_param(weight_bits, rmin, rmax)
+        scale, zero_point = update_quantization_param(weight_bits, rmin, rmax)
+        module.scale.copy_(scale)
+        module.zero_point.copy_(zero_point)
         weight = self._quantize(weight_bits, module, weight)
         weight = self._dequantize(module, weight)
         module.weight_bits = torch.Tensor([weight_bits])
+        # Weight can not be in-place modified, so when use torch.nn.DataParallel, this update
+        # will be lost after each forward process. However, this update takes effect on each
+        # replicated module during each forward process, which will make the quantized weight
+        # be used correctly.
         wrapper.module.weight = weight
         return weight
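The copy_() pattern above, together with the new comment about `wrapper.module.weight`, is the heart of the DataParallel fix: `module.scale = ...` rebinds the buffer entry to a brand-new tensor on whichever replica ran the forward pass, whereas copy_() writes into the tensor that is already registered, so the tracked state survives replication. The weight itself is intentionally assigned only for the duration of the forward pass, as the committed comment explains. A minimal identity check (plain PyTorch, not NNI code):

```python
import torch
import torch.nn as nn

m = nn.Linear(2, 2)
m.register_buffer('scale', torch.tensor([1.0]))

original = m.scale
m.scale = torch.tensor([2.0])        # reassignment: the buffer now points at a new tensor
print(m.scale is original)           # False: the originally registered tensor is abandoned

m.register_buffer('scale', torch.tensor([1.0]))
original = m.scale
m.scale.copy_(torch.tensor([2.0]))   # in-place copy: same tensor object, new value
print(m.scale is original)           # True
print(original)                      # tensor([2.])
```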
@@ -499,23 +505,30 @@ class QAT_Quantizer(Quantizer):
         config = wrapper.config
         module = wrapper.module
         input_bits = get_bits_length(config, 'input')
-        module.input_bits = torch.Tensor([input_bits])
+        module.input_bit = torch.tensor([input_bits])
         quant_start_step = config.get('quant_start_step', 0)
         assert input_bits >= 1, "quant bits length should be at least 1"

-        if quant_start_step > self.bound_model.steps:
-            module.tracked_min_input, module.tracked_max_input = torch.min(inputs), torch.max(inputs)
+        if quant_start_step > int(self.bound_model.steps):
+            current_min, current_max = torch.min(inputs), torch.max(inputs)
+            module.tracked_min_input.copy_(current_min)
+            module.tracked_max_input.copy_(current_max)
             return inputs

         # we dont update output quantization parameters in evaluation stage
         if wrapper.training:
             current_min, current_max = torch.min(inputs), torch.max(inputs)
-            module.tracked_min_input = update_ema(module.tracked_min_input, current_min,
-                                                  module.ema_decay)
-            module.tracked_max_input = update_ema(module.tracked_max_input, current_max,
-                                                  module.ema_decay)
-            module.scale, module.zero_point = update_quantization_param(
+            current_min = update_ema(module.tracked_min_input, current_min, module.ema_decay)
+            current_max = update_ema(module.tracked_max_input, current_max, module.ema_decay)
+            module.tracked_min_input.copy_(current_min)
+            module.tracked_max_input.copy_(current_max)
+            scale, zero_point = update_quantization_param(
                 input_bits, module.tracked_min_input, module.tracked_max_input)
+            module.scale.copy_(scale)
+            module.zero_point.copy_(zero_point)

         inp = self._quantize(input_bits, module, inputs)
         inp = self._dequantize(module, inp)
         return inp
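For readers unfamiliar with the helper, update_ema maintains an exponential moving average of the observed activation range. The sketch below assumes the conventional form new = decay * old + (1 - decay) * current (the exact NNI implementation may differ) and shows how it combines with the in-place copy_() used above.

```python
import torch

def update_ema(biased_ema, value, decay):
    # Assumed EMA form: smooth the running statistic toward the current batch value.
    return biased_ema * decay + (1 - decay) * value

tracked_min_input = torch.tensor([0.0])
ema_decay = torch.tensor([0.99])

# Per-batch minima drifting downward over three steps.
for batch_min in (-1.0, -2.0, -1.5):
    current_min = update_ema(tracked_min_input, torch.tensor([batch_min]), ema_decay)
    tracked_min_input.copy_(current_min)   # same in-place update pattern as the hunk above

print(tracked_min_input)   # roughly tensor([-0.0446]): heavily smoothed by the 0.99 decay
```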
@@ -528,19 +541,26 @@ class QAT_Quantizer(Quantizer):
         quant_start_step = config.get('quant_start_step', 0)
         assert output_bits >= 1, "quant bits length should be at least 1"

-        if quant_start_step > self.bound_model.steps:
-            module.tracked_min_output, module.tracked_max_output = torch.min(output), torch.max(output)
+        if quant_start_step > int(self.bound_model.steps):
+            current_min, current_max = torch.min(output), torch.max(output)
+            module.tracked_min_output.copy_(current_min)
+            module.tracked_max_output.copy_(current_max)
             return output

         # we dont update output quantization parameters in evaluation stage
         if wrapper.training:
             current_min, current_max = torch.min(output), torch.max(output)
-            module.tracked_min_output = update_ema(module.tracked_min_output, current_min,
-                                                   module.ema_decay)
-            module.tracked_max_output = update_ema(module.tracked_max_output, current_max,
-                                                   module.ema_decay)
-            module.scale, module.zero_point = update_quantization_param(
+            tracked_min_output = update_ema(module.tracked_min_output, current_min,
+                                            module.ema_decay)
+            tracked_max_output = update_ema(module.tracked_max_output, current_max,
+                                            module.ema_decay)
+            module.tracked_min_output.copy_(tracked_min_output)
+            module.tracked_max_output.copy_(tracked_max_output)
+            scale, zero_point = update_quantization_param(
                 output_bits, module.tracked_min_output, module.tracked_max_output)
+            module.scale.copy_(scale)
+            module.zero_point.copy_(zero_point)

         out = self._quantize(output_bits, module, output)
         out = self._dequantize(module, out)
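The _quantize/_dequantize pair invoked in these hunks is the usual asymmetric fake-quantization step driven by the scale and zero_point buffers. The helper below is a hedged reconstruction of that scheme (an assumption, not the NNI source) to show what the copied-in scale and zero point are used for.

```python
import torch

def update_quantization_param(bits, rmin, rmax):
    # Assumed asymmetric scheme: map [rmin, rmax] (extended to include 0) onto [qmin, qmax].
    qmin, qmax = 0, (1 << bits) - 1
    rmin = torch.minimum(rmin, torch.tensor(0.0))
    rmax = torch.maximum(rmax, torch.tensor(0.0))
    scale = (rmax - rmin) / (qmax - qmin)
    zero_point = qmin - torch.round(rmin / scale)
    return scale, torch.clamp(zero_point, qmin, qmax)

def fake_quantize(x, scale, zero_point, bits):
    qmin, qmax = 0, (1 << bits) - 1
    q = torch.clamp(torch.round(x / scale + zero_point), qmin, qmax)   # quantize
    return (q - zero_point) * scale                                    # dequantize

x = torch.tensor([-0.4, 0.0, 1.3])
scale, zero_point = update_quantization_param(8, x.min(), x.max())
print(fake_quantize(x, scale, zero_point, 8))   # close approximation of x on the 8-bit grid
```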
@@ -645,7 +665,7 @@ class QAT_Quantizer(Quantizer):
         """
         override `compressor` `step` method, quantization only happens after certain number of steps
         """
-        self.bound_model.steps += 1
+        self.bound_model.steps.add_(1)


 class DoReFaQuantizer(Quantizer):
......
@@ -602,6 +602,8 @@ class Quantizer(Compressor):
     """

     def __init__(self, model, config_list, optimizer=None, dummy_input=None):
+        if isinstance(model, torch.nn.DataParallel):
+            model = model.module
         self.identity_wrappers = []
         self.conv_bn_patterns = {}
         self.find_conv_bn_patterns(model, dummy_input)
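A small check of the new guard in Quantizer.__init__: a model that has already been wrapped in DataParallel is silently unwrapped, so either form can be handed to the quantizer. The snippet is runnable on CPU, where constructing DataParallel falls back to a thin wrapper.

```python
import torch
import torch.nn as nn

net = nn.Linear(4, 4)
model = nn.DataParallel(net)

# Same unwrapping logic as the added lines above.
if isinstance(model, torch.nn.DataParallel):
    model = model.module

assert model is net   # the quantizer now operates on the underlying module
```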
@@ -892,12 +894,21 @@ class QuantGrad(torch.autograd.Function):
             zero_point = wrapper.module.zero_point
         else:
             scale, zero_point = None, None
-        ctx.save_for_backward(tensor, torch.Tensor([quant_type]), scale, zero_point, qmin, qmax)
+        ctx.save_for_backward(tensor)
+        # Only tensors have gradients flowing back needs to be saved by save_for_backward.
+        # Others should directly assign to ctx.
+        ctx.scale = scale
+        ctx.zero_point = zero_point
+        ctx.quant_type = quant_type
+        ctx.qmin, ctx.qmax = qmin, qmax
         return output

     @classmethod
     def backward(cls, ctx, grad_output):
-        tensor, quant_type, scale, zero_point, qmin, qmax = ctx.saved_variables
+        tensor = ctx.saved_variables[0]
+        scale, zero_point = ctx.scale, ctx.zero_point
+        qmin, qmax = ctx.qmin, ctx.qmax
+        quant_type = ctx.quant_type
         output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
         return output, None, None, None
......
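The QuantGrad change follows the standard custom-autograd pattern: only tensors whose gradients flow back go through save_for_backward, while plain metadata is attached to ctx directly (the diff still reads ctx.saved_variables, an older alias of ctx.saved_tensors). A self-contained straight-through-estimator sketch of the same pattern, with illustrative names rather than the NNI class:

```python
import torch

class FakeQuantSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, tensor, scale, zero_point, qmin, qmax):
        # Only the tensor that needs a gradient is saved via save_for_backward;
        # non-tensor metadata is stored directly on ctx, as in the hunk above.
        ctx.save_for_backward(tensor)
        ctx.qmin, ctx.qmax = qmin, qmax
        q = torch.clamp(torch.round(tensor / scale + zero_point), qmin, qmax)
        return (q - zero_point) * scale

    @staticmethod
    def backward(ctx, grad_output):
        tensor, = ctx.saved_tensors
        # Straight-through estimator: gradients pass through the rounding unchanged;
        # one return value per forward() argument, None for the non-tensor ones.
        return grad_output, None, None, None, None

x = torch.randn(4, requires_grad=True)
y = FakeQuantSTE.apply(x, 0.1, 0.0, 0, 255)
y.sum().backward()
print(x.grad)   # all ones under the straight-through assumption
```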