Unverified commit c66b7472, authored by Ningxin Zheng and committed via GitHub
Browse files

Fix bug for speedup module and enhance the Ut for speedup (#3279)



Fix bug for speedup module and enhance the Ut for speedup
Signed-off-by: Ningxin Zheng <Ningxin.Zheng@microsoft.com>
parent d1e94573
......@@ -891,12 +891,7 @@ def conv2d_mask(module_masks, mask):
sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
index = torch.nonzero(weight_mask.abs().sum(
sum_idx) != 0, as_tuple=True)[0]
if len(index) == weight_mask.shape[dim]: # full mask
index = None
if index is None:
return None, None, None
else:
index = index.long().to(weight_mask.device)
weight_cmask = CoarseMask(num_dim=4)
weight_cmask.add_index_mask(dim=dim, index=index)
......@@ -962,6 +957,7 @@ def conv2d_inshape(module_masks, mask):
# the same conv layer may be accessed more
# than once, such as a concat operation.
# mask conflict should be solved by fix_mask_conflict before speedup
assert module_masks.input_mask == mask
# shape changes pass through depths wise conv layers
......
......@@ -31,6 +31,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
# if the input is the path of the mask_file
assert os.path.exists(masks)
masks = torch.load(masks)
assert len(masks) > 0, 'Mask tensor cannot be empty'
# if the user uses the model and dummy_input to trace the model, we
# should get the traced model handly, so that, we only trace the
# model once, GroupMaskConflict and ChannelMaskConflict will reuse
......@@ -127,6 +128,7 @@ class CatMaskPadding(MaskFix):
for layer in layers:
if layer in self.masks:
continue
module = name_to_module[layer]
w_shape = module.weight.data.size()
w_mask = torch.ones(w_shape).to(device)
......@@ -136,6 +138,7 @@ class CatMaskPadding(MaskFix):
b_shape = module.bias.data.size()
b_mask = torch.ones(b_shape).to(device)
self.masks[layer] = {'weight': w_mask, 'bias': b_mask}
return self.masks
......@@ -250,6 +253,10 @@ class ChannelMaskConflict(MaskFix):
self.model, self.dummy_input, self.traced)
depen_sets = channel_depen.dependency_sets
sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)
(_tmp_name, _tmp_tensor) = list(self.masks.items())[0]
device = _tmp_tensor['weight'].device
for dset in depen_sets:
if len(dset) <= 1:
continue
......@@ -301,7 +308,7 @@ class ChannelMaskConflict(MaskFix):
for i, dim_mask in enumerate(channel_masks):
if dim_mask is None:
channel_masks[i] = torch.ones(num_channels).int()
channel_masks[i] = torch.ones(num_channels).int().to(device)
# merge masks with 'or'
merged_channel_mask = channel_masks[0].clone()
......
......@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import os
import psutil
import sys
import numpy as np
import torch
......@@ -128,6 +129,18 @@ def generate_random_sparsity(model):
'sparsity': sparsity})
return cfg_list
def generate_random_sparsity_v2(model):
    """
    Build a random pruning config list, keeping only about half of the
    Conv2d layers (each conv layer is selected with probability ~0.5).
    """
    configs = []
    for layer_name, layer in model.named_modules():
        if not isinstance(layer, nn.Conv2d):
            continue
        # coin flip: skip this conv layer roughly half of the time
        if np.random.uniform(0, 1.0) <= 0.5:
            continue
        ratio = np.random.uniform(0.5, 0.99)
        configs.append({
            'op_types': ['Conv2d'],
            'op_names': [layer_name],
            'sparsity': ratio,
        })
    return configs
def zero_bn_bias(model):
with torch.no_grad():
......@@ -292,10 +305,19 @@ class SpeedupTestCase(TestCase):
# Example: https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=16282
def test_speedup_integration(self):
for model_name in ['resnet18', 'squeezenet1_1',
'mobilenet_v2', 'densenet121',
# skip this test on windows(7GB mem available) due to memory limit
# Note: hack trick, may be updated in the future
if 'win' in sys.platform or 'Win'in sys.platform:
print('Skip test_speedup_integration on windows due to memory limit!')
return
Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2]
for model_name in ['resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121' , 'densenet169',
# 'inception_v3' inception is too large and may fail the pipeline
'densenet169', 'resnet50']:
'resnet50']:
for gen_cfg_func in Gen_cfg_funcs:
kwargs = {
'pretrained': True
}
......@@ -305,14 +327,14 @@ class SpeedupTestCase(TestCase):
'pretrained': False,
'groups': 4
}
Model = getattr(models, model_name)
net = Model(**kwargs).to(device)
speedup_model = Model(**kwargs).to(device)
net.eval() # this line is necessary
speedup_model.eval()
# random generate the prune config for the pruner
cfgs = generate_random_sparsity(net)
cfgs = gen_cfg_func(net)
print("Testing {} with compression config \n {}".format(model_name, cfgs))
pruner = L1FilterPruner(net, cfgs)
pruner.compress()
pruner.export_model(MODEL_FILE, MASK_FILE)
......@@ -339,6 +361,7 @@ class SpeedupTestCase(TestCase):
assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
def test_channel_prune(self):
orig_net = resnet18(num_classes=10).to(device)
channel_prune(orig_net)
......@@ -369,7 +392,9 @@ class SpeedupTestCase(TestCase):
(abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
def tearDown(self):
    """Remove the model/mask files exported during a test, if present."""
    # best-effort cleanup: each artifact is deleted only when it exists
    for artifact in (MODEL_FILE, MASK_FILE):
        if os.path.exists(artifact):
            os.remove(artifact)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment