"...composable_kernel_rocm.git" did not exist on "1c9a5ff42b3df2bf481f05c2f9e9f142eabc8fd7"
Unverified Commit e779d7f3 authored by lin bin's avatar lin bin Committed by GitHub
Browse files

[quantization] Fix quantize config setting in speedup example (#4276)

parent 78ea3767
...@@ -83,16 +83,16 @@ def quantization_aware_training_example(train_loader, test_loader, device): ...@@ -83,16 +83,16 @@ def quantization_aware_training_example(train_loader, test_loader, device):
model = NaiveModel() model = NaiveModel()
configure_list = [{ configure_list = [{
'quant_types': ['weight', 'output'], 'quant_types': ['input', 'weight'],
'quant_bits': {'weight':8, 'output':8}, 'quant_bits': {'input':8, 'weight':8},
'op_names': ['conv1'] 'op_names': ['conv1']
}, { }, {
'quant_types': ['output'], 'quant_types': ['output'],
'quant_bits': {'output':8}, 'quant_bits': {'output':8},
'op_names': ['relu1'] 'op_names': ['relu1']
}, { }, {
'quant_types': ['weight', 'output'], 'quant_types': ['input', 'weight'],
'quant_bits': {'weight':8, 'output':8}, 'quant_bits': {'input':8, 'weight':8},
'op_names': ['conv2'] 'op_names': ['conv2']
}, { }, {
'quant_types': ['output'], 'quant_types': ['output'],
......
...@@ -97,10 +97,10 @@ class QuantizationSpeedupTestCase(TestCase): ...@@ -97,10 +97,10 @@ class QuantizationSpeedupTestCase(TestCase):
model = BackboneModel() model = BackboneModel()
configure_list = { configure_list = {
'conv1':{'weight_bit':8, 'activation_bit':8}, 'conv1':{'weight_bits':8, 'output_bits':8},
'conv2':{'weight_bit':32, 'activation_bit':32}, 'conv2':{'weight_bits':32, 'output_bits':32},
'fc1':{'weight_bit':16, 'activation_bit':16}, 'fc1':{'weight_bits':16, 'output_bits':16},
'fc2':{'weight_bit':8, 'activation_bit':8} 'fc2':{'weight_bits':8, 'output_bits':8}
} }
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
...@@ -126,16 +126,16 @@ class QuantizationSpeedupTestCase(TestCase): ...@@ -126,16 +126,16 @@ class QuantizationSpeedupTestCase(TestCase):
model = BackboneModel() model = BackboneModel()
configure_list = [{ configure_list = [{
'quant_types': ['weight', 'output'], 'quant_types': ['input', 'weight'],
'quant_bits': {'weight':8, 'output':8}, 'quant_bits': {'input':8, 'weight':8},
'op_names': ['conv1'] 'op_names': ['conv1']
}, { }, {
'quant_types': ['output'], 'quant_types': ['output'],
'quant_bits': {'output':8}, 'quant_bits': {'output':8},
'op_names': ['relu1'] 'op_names': ['relu1']
}, { }, {
'quant_types': ['weight', 'output'], 'quant_types': ['input', 'weight'],
'quant_bits': {'weight':8, 'output':8}, 'quant_bits': {'input':8, 'weight':8},
'op_names': ['conv2'] 'op_names': ['conv2']
}, { }, {
'quant_types': ['output'], 'quant_types': ['output'],
...@@ -145,8 +145,9 @@ class QuantizationSpeedupTestCase(TestCase): ...@@ -145,8 +145,9 @@ class QuantizationSpeedupTestCase(TestCase):
] ]
# finetune the model by using QAT # finetune the model by using QAT
dummy_input = torch.randn(1, 1, 28, 28)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5) optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
quantizer = QAT_Quantizer(model, configure_list, optimizer) quantizer = QAT_Quantizer(model, configure_list, optimizer, dummy_input)
quantizer.compress() quantizer.compress()
model.to(self.device) model.to(self.device)
...@@ -178,13 +179,13 @@ class QuantizationSpeedupTestCase(TestCase): ...@@ -178,13 +179,13 @@ class QuantizationSpeedupTestCase(TestCase):
model = vgg16() model = vgg16()
configure_list = { configure_list = {
'features.0':{'weight_bit':8, 'activation_bit':8}, 'features.0':{'weight_bits':8, 'output_bits':8},
'features.1':{'weight_bit':32, 'activation_bit':32}, 'features.1':{'weight_bits':32, 'output_bits':32},
'features.2':{'weight_bit':16, 'activation_bit':16}, 'features.2':{'weight_bits':16, 'output_bits':16},
'features.4':{'weight_bit':8, 'activation_bit':8}, 'features.4':{'weight_bits':8, 'output_bits':8},
'features.7':{'weight_bit':8, 'activation_bit':8}, 'features.7':{'weight_bits':8, 'output_bits':8},
'features.8':{'weight_bit':8, 'activation_bit':8}, 'features.8':{'weight_bits':8, 'output_bits':8},
'features.11':{'weight_bit':8, 'activation_bit':8} 'features.11':{'weight_bits':8, 'output_bits':8}
} }
model.to(self.device) model.to(self.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment