Unverified Commit 6b0ecee6 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

fix compressor ut (#1997)

parent 2de52a89
...@@ -135,12 +135,12 @@ class CompressorTestCase(TestCase): ...@@ -135,12 +135,12 @@ class CompressorTestCase(TestCase):
model.conv2.weight.data = torch.tensor(w).float() model.conv2.weight.data = torch.tensor(w).float()
layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2) layer = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
masks = pruner.calc_mask(layer, config_list[0]) masks = pruner.calc_mask(layer, config_list[0], if_calculated=torch.tensor(0))
assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.])) assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 45., 45., 0., 0., 45., 45., 45., 45.]))
pruner.update_epoch(1) pruner.update_epoch(1)
model.conv2.weight.data = torch.tensor(w).float() model.conv2.weight.data = torch.tensor(w).float()
masks = pruner.calc_mask(layer, config_list[1]) masks = pruner.calc_mask(layer, config_list[1], if_calculated=torch.tensor(0))
assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.])) assert all(torch.sum(masks['weight'], (1, 2, 3)).numpy() == np.array([45., 45., 0., 0., 0., 0., 0., 0., 45., 45.]))
@tf2 @tf2
...@@ -187,9 +187,9 @@ class CompressorTestCase(TestCase): ...@@ -187,9 +187,9 @@ class CompressorTestCase(TestCase):
model.conv1.weight.data = torch.tensor(w).float() model.conv1.weight.data = torch.tensor(w).float()
model.conv2.weight.data = torch.tensor(w).float() model.conv2.weight.data = torch.tensor(w).float()
layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1) layer1 = torch_compressor.compressor.LayerInfo('conv1', model.conv1)
mask1 = pruner.calc_mask(layer1, config_list[0]) mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2) layer2 = torch_compressor.compressor.LayerInfo('conv2', model.conv2)
mask2 = pruner.calc_mask(layer2, config_list[1]) mask2 = pruner.calc_mask(layer2, config_list[1], if_calculated=torch.tensor(0))
assert all(torch.sum(mask1['weight'], (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.])) assert all(torch.sum(mask1['weight'], (1, 2, 3)).numpy() == np.array([0., 27., 27., 27., 27.]))
assert all(torch.sum(mask2['weight'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.])) assert all(torch.sum(mask2['weight'], (1, 2, 3)).numpy() == np.array([0., 0., 0., 27., 27.]))
...@@ -215,9 +215,9 @@ class CompressorTestCase(TestCase): ...@@ -215,9 +215,9 @@ class CompressorTestCase(TestCase):
pruner = torch_compressor.SlimPruner(model, config_list) pruner = torch_compressor.SlimPruner(model, config_list)
layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1) layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
mask1 = pruner.calc_mask(layer1, config_list[0]) mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2) layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0]) mask2 = pruner.calc_mask(layer2, config_list[0], if_calculated=torch.tensor(0))
assert all(mask1['weight'].numpy() == np.array([0., 1., 1., 1., 1.])) assert all(mask1['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask2['weight'].numpy() == np.array([0., 1., 1., 1., 1.])) assert all(mask2['weight'].numpy() == np.array([0., 1., 1., 1., 1.]))
assert all(mask1['bias'].numpy() == np.array([0., 1., 1., 1., 1.])) assert all(mask1['bias'].numpy() == np.array([0., 1., 1., 1., 1.]))
...@@ -229,9 +229,9 @@ class CompressorTestCase(TestCase): ...@@ -229,9 +229,9 @@ class CompressorTestCase(TestCase):
pruner = torch_compressor.SlimPruner(model, config_list) pruner = torch_compressor.SlimPruner(model, config_list)
layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1) layer1 = torch_compressor.compressor.LayerInfo('bn1', model.bn1)
mask1 = pruner.calc_mask(layer1, config_list[0]) mask1 = pruner.calc_mask(layer1, config_list[0], if_calculated=torch.tensor(0))
layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2) layer2 = torch_compressor.compressor.LayerInfo('bn2', model.bn2)
mask2 = pruner.calc_mask(layer2, config_list[0]) mask2 = pruner.calc_mask(layer2, config_list[0], if_calculated=torch.tensor(0))
assert all(mask1['weight'].numpy() == np.array([0., 0., 0., 1., 1.])) assert all(mask1['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask2['weight'].numpy() == np.array([0., 0., 0., 1., 1.])) assert all(mask2['weight'].numpy() == np.array([0., 0., 0., 1., 1.]))
assert all(mask1['bias'].numpy() == np.array([0., 0., 0., 1., 1.])) assert all(mask1['bias'].numpy() == np.array([0., 0., 0., 1., 1.]))
...@@ -268,14 +268,14 @@ class CompressorTestCase(TestCase): ...@@ -268,14 +268,14 @@ class CompressorTestCase(TestCase):
# test ema # test ema
x = torch.tensor([[-0.2, 0], [0.1, 0.2]]) x = torch.tensor([[-0.2, 0], [0.1, 0.2]])
out = model.relu(x) out = model.relu(x)
assert math.isclose(model.relu.tracked_min_biased, 0, abs_tol=eps) assert math.isclose(model.relu.module.tracked_min_biased, 0, abs_tol=eps)
assert math.isclose(model.relu.tracked_max_biased, 0.002, abs_tol=eps) assert math.isclose(model.relu.module.tracked_max_biased, 0.002, abs_tol=eps)
quantizer.step() quantizer.step()
x = torch.tensor([[0.2, 0.4], [0.6, 0.8]]) x = torch.tensor([[0.2, 0.4], [0.6, 0.8]])
out = model.relu(x) out = model.relu(x)
assert math.isclose(model.relu.tracked_min_biased, 0.002, abs_tol=eps) assert math.isclose(model.relu.module.tracked_min_biased, 0.002, abs_tol=eps)
assert math.isclose(model.relu.tracked_max_biased, 0.00998, abs_tol=eps) assert math.isclose(model.relu.module.tracked_max_biased, 0.00998, abs_tol=eps)
if __name__ == '__main__': if __name__ == '__main__':
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment