You need to sign in or sign up before continuing.
Unverified Commit 9270f9b8 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

Update confidence=8 in speed up test (#3967)

parent b65830e0
...@@ -295,7 +295,7 @@ class SpeedupTestCase(TestCase): ...@@ -295,7 +295,7 @@ class SpeedupTestCase(TestCase):
mask_out = model(dummy_input) mask_out = model(dummy_input)
model.train() model.train()
ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=2) ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=8)
ms.speedup_model() ms.speedup_model()
assert model.training assert model.training
...@@ -331,7 +331,7 @@ class SpeedupTestCase(TestCase): ...@@ -331,7 +331,7 @@ class SpeedupTestCase(TestCase):
new_model = TransposeModel() new_model = TransposeModel()
state_dict = torch.load(MODEL_FILE) state_dict = torch.load(MODEL_FILE)
new_model.load_state_dict(state_dict) new_model.load_state_dict(state_dict)
ms = ModelSpeedup(new_model, dummy_input, MASK_FILE, confidence=2) ms = ModelSpeedup(new_model, dummy_input, MASK_FILE, confidence=8)
ms.speedup_model() ms.speedup_model()
zero_bn_bias(ori_model) zero_bn_bias(ori_model)
zero_bn_bias(new_model) zero_bn_bias(new_model)
...@@ -405,7 +405,7 @@ class SpeedupTestCase(TestCase): ...@@ -405,7 +405,7 @@ class SpeedupTestCase(TestCase):
if speedup_cfg is None: if speedup_cfg is None:
speedup_cfg = {} speedup_cfg = {}
ms = ModelSpeedup(speedup_model, data, ms = ModelSpeedup(speedup_model, data,
MASK_FILE, confidence=2, **speedup_cfg) MASK_FILE, confidence=8, **speedup_cfg)
ms.speedup_model() ms.speedup_model()
speedup_model.eval() speedup_model.eval()
...@@ -439,7 +439,7 @@ class SpeedupTestCase(TestCase): ...@@ -439,7 +439,7 @@ class SpeedupTestCase(TestCase):
net.eval() net.eval()
data = torch.randn(BATCH_SIZE, 3, 128, 128).to(device) data = torch.randn(BATCH_SIZE, 3, 128, 128).to(device)
ms = ModelSpeedup(net, data, MASK_FILE, confidence=2) ms = ModelSpeedup(net, data, MASK_FILE, confidence=8)
ms.speedup_model() ms.speedup_model()
ms.bound_model(data) ms.bound_model(data)
...@@ -461,7 +461,7 @@ class SpeedupTestCase(TestCase): ...@@ -461,7 +461,7 @@ class SpeedupTestCase(TestCase):
pruner.compress() pruner.compress()
model(dummy_input) model(dummy_input)
pruner.export_model(MODEL_FILE, MASK_FILE) pruner.export_model(MODEL_FILE, MASK_FILE)
ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=2) ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=8)
ms.speedup_model() ms.speedup_model()
def test_finegrained_speedup(self): def test_finegrained_speedup(self):
...@@ -490,7 +490,7 @@ class SpeedupTestCase(TestCase): ...@@ -490,7 +490,7 @@ class SpeedupTestCase(TestCase):
print(model) print(model)
pruner.export_model(MODEL_FILE, MASK_FILE) pruner.export_model(MODEL_FILE, MASK_FILE)
pruner._unwrap_model() pruner._unwrap_model()
ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=4) ms = ModelSpeedup(model, dummy_input, MASK_FILE, confidence=8)
ms.speedup_model() ms.speedup_model()
print("Fine-grained speeduped model") print("Fine-grained speeduped model")
print(model) print(model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment