"...src/static/git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "45236e18dd70927e5a9e5b59d5fade9f4dd687d0"
Unverified Commit 93c53503 authored by chenbohua3's avatar chenbohua3 Committed by GitHub
Browse files

should use .data in quantization wrapper (#4113)

parent da294255
......@@ -529,13 +529,13 @@ class QuantizerModuleWrapper(torch.nn.Module):
else:
self.module.register_parameter('old_weight', torch.nn.Parameter(self.module.weight))
delattr(self.module, 'weight')
self.module.register_buffer('weight', self.module.old_weight)
self.module.register_buffer('weight', self.module.old_weight.data)
# for batch normalization folding
if self.bn_module is not None:
if _check_bias(self.module):
self.module.register_parameter('old_bias', torch.nn.Parameter(self.module.bias))
init_tensor = self.module.old_bias
init_tensor = self.module.old_bias.data
else:
init_tensor = torch.zeros_like(self.bn_module.weight)
delattr(self.module, 'bias')
......
......@@ -305,6 +305,33 @@ class CompressorTestCase(TestCase):
self.assertTrue(calibration_config is not None)
self.assertTrue(len(calibration_config) == 4)
def test_torch_quantizer_weight_type(self):
quantizer_list = [
torch_quantizer.QAT_Quantizer,
torch_quantizer.LsqQuantizer,
torch_quantizer.ObserverQuantizer,
torch_quantizer.NaiveQuantizer,
torch_quantizer.DoReFaQuantizer]
for quantizer_type in quantizer_list:
model = TorchModel().eval()
config_list = [{
'quant_types': ['weight'],
'quant_bits': 8,
'op_types': ['Conv2d', 'Linear']
}]
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
dummy = torch.randn(1, 1, 28, 28)
if quantizer_type == torch_quantizer.QAT_Quantizer:
quantizer_type(model, config_list, optimizer, dummy_input=dummy)
else:
quantizer_type(model, config_list, optimizer)
self.assertFalse(isinstance(model.conv1.module.weight, torch.nn.Parameter))
self.assertFalse(isinstance(model.conv2.module.weight, torch.nn.Parameter))
self.assertFalse(isinstance(model.fc1.module.weight, torch.nn.Parameter))
self.assertFalse(isinstance(model.fc2.module.weight, torch.nn.Parameter))
def test_torch_QAT_quantizer(self):
model = TorchModel()
config_list = [{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment