Unverified commit e779d7f3 authored by lin bin, committed by GitHub

[quantization] Fix quantize config setting in speedup example (#4276)

parent 78ea3767
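
In short: the QAT example's conv layers switch from quantizing `['weight', 'output']` to `['input', 'weight']`, the per-layer calibration keys are renamed from `weight_bit`/`activation_bit` to `weight_bits`/`output_bits`, and `QAT_Quantizer` now receives a `dummy_input` (presumably so the quantizer can trace the graph). A minimal sketch of the corrected setup, assuming NNI 2.x import paths and the `NaiveModel` defined in the example being patched:

```python
import torch
from nni.algorithms.compression.pytorch.quantization import QAT_Quantizer

model = NaiveModel()  # MNIST-style model from the speedup example; assumed in scope
configure_list = [{
    'quant_types': ['input', 'weight'],      # conv layers quantize input and weight
    'quant_bits': {'input': 8, 'weight': 8},
    'op_names': ['conv1']
}, {
    'quant_types': ['output'],               # activations are quantized at the ReLU
    'quant_bits': {'output': 8},
    'op_names': ['relu1']
}]

dummy_input = torch.randn(1, 1, 28, 28)     # added by this commit; lets the quantizer record the graph
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = QAT_Quantizer(model, configure_list, optimizer, dummy_input)
quantizer.compress()
```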
@@ -83,16 +83,16 @@ def quantization_aware_training_example(train_loader, test_loader, device):
     model = NaiveModel()
     configure_list = [{
-        'quant_types': ['weight', 'output'],
-        'quant_bits': {'weight':8, 'output':8},
+        'quant_types': ['input', 'weight'],
+        'quant_bits': {'input':8, 'weight':8},
         'op_names': ['conv1']
     }, {
         'quant_types': ['output'],
         'quant_bits': {'output':8},
         'op_names': ['relu1']
     }, {
-        'quant_types': ['weight', 'output'],
-        'quant_bits': {'weight':8, 'output':8},
+        'quant_types': ['input', 'weight'],
+        'quant_bits': {'input':8, 'weight':8},
         'op_names': ['conv2']
     }, {
         'quant_types': ['output'],
@@ -97,10 +97,10 @@ class QuantizationSpeedupTestCase(TestCase):
         model = BackboneModel()
         configure_list = {
-            'conv1':{'weight_bit':8, 'activation_bit':8},
-            'conv2':{'weight_bit':32, 'activation_bit':32},
-            'fc1':{'weight_bit':16, 'activation_bit':16},
-            'fc2':{'weight_bit':8, 'activation_bit':8}
+            'conv1':{'weight_bits':8, 'output_bits':8},
+            'conv2':{'weight_bits':32, 'output_bits':32},
+            'fc1':{'weight_bits':16, 'output_bits':16},
+            'fc2':{'weight_bits':8, 'output_bits':8}
         }
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
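
The dict-style config above is the calibration format consumed by the TensorRT speedup path; only the key names change here. A hedged sketch of how it might be fed to the engine; the `ModelSpeedupTensorRT` signature and the `(32, 1, 28, 28)` input shape follow the NNI quantization-speedup docs and are assumptions, not part of this diff:

```python
from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT

model = BackboneModel()  # test model from this file; assumed in scope
calibration_config = {
    'conv1': {'weight_bits': 8, 'output_bits': 8},
    'conv2': {'weight_bits': 32, 'output_bits': 32},  # 32 bits: effectively kept in float
    'fc1': {'weight_bits': 16, 'output_bits': 16},
    'fc2': {'weight_bits': 8, 'output_bits': 8},
}
engine = ModelSpeedupTensorRT(model, (32, 1, 28, 28), config=calibration_config, batchsize=32)
engine.compress()
```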
@@ -126,16 +126,16 @@ class QuantizationSpeedupTestCase(TestCase):
         model = BackboneModel()
         configure_list = [{
-            'quant_types': ['weight', 'output'],
-            'quant_bits': {'weight':8, 'output':8},
+            'quant_types': ['input', 'weight'],
+            'quant_bits': {'input':8, 'weight':8},
             'op_names': ['conv1']
         }, {
             'quant_types': ['output'],
             'quant_bits': {'output':8},
             'op_names': ['relu1']
         }, {
-            'quant_types': ['weight', 'output'],
-            'quant_bits': {'weight':8, 'output':8},
+            'quant_types': ['input', 'weight'],
+            'quant_bits': {'input':8, 'weight':8},
             'op_names': ['conv2']
         }, {
             'quant_types': ['output'],
@@ -145,8 +145,9 @@ class QuantizationSpeedupTestCase(TestCase):
         ]
         # finetune the model by using QAT
+        dummy_input = torch.randn(1, 1, 28, 28)
         optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
-        quantizer = QAT_Quantizer(model, configure_list, optimizer)
+        quantizer = QAT_Quantizer(model, configure_list, optimizer, dummy_input)
         quantizer.compress()
         model.to(self.device)
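
After `compress()`, the quantizer can export the calibration data that the speedup engine consumes. A sketch of that step; `export_model`'s signature is taken from the NNI QAT documentation, and the file paths are placeholders:

```python
# Export the fine-tuned weights plus the recorded calibration parameters;
# the returned dict uses the weight_bits/output_bits keys shown in this diff.
calibration_config = quantizer.export_model(
    model_path='mnist_model.pth',
    calibration_path='mnist_calibration.pth')
```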
@@ -178,13 +179,13 @@ class QuantizationSpeedupTestCase(TestCase):
         model = vgg16()
         configure_list = {
-            'features.0':{'weight_bit':8, 'activation_bit':8},
-            'features.1':{'weight_bit':32, 'activation_bit':32},
-            'features.2':{'weight_bit':16, 'activation_bit':16},
-            'features.4':{'weight_bit':8, 'activation_bit':8},
-            'features.7':{'weight_bit':8, 'activation_bit':8},
-            'features.8':{'weight_bit':8, 'activation_bit':8},
-            'features.11':{'weight_bit':8, 'activation_bit':8}
+            'features.0':{'weight_bits':8, 'output_bits':8},
+            'features.1':{'weight_bits':32, 'output_bits':32},
+            'features.2':{'weight_bits':16, 'output_bits':16},
+            'features.4':{'weight_bits':8, 'output_bits':8},
+            'features.7':{'weight_bits':8, 'output_bits':8},
+            'features.8':{'weight_bits':8, 'output_bits':8},
+            'features.11':{'weight_bits':8, 'output_bits':8}
         }
         model.to(self.device)