"examples/elementwise/example_elementwise_add_tma_1d.py" did not exist on "bfb5b0a3ca3555dc4f78a9e6926cfd2e7df7f596"
Unverified commit c66b7472, authored by Ningxin Zheng and committed by GitHub

Fix bug for speedup module and enhance the UT for speedup (#3279)

Signed-off-by: Ningxin <Ningxin.Zheng@microsoft.com>
parent d1e94573
@@ -891,23 +891,18 @@ def conv2d_mask(module_masks, mask):
         sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
         index = torch.nonzero(weight_mask.abs().sum(
             sum_idx) != 0, as_tuple=True)[0]
-        if len(index) == weight_mask.shape[dim]:  # full mask
-            index = None
-        if index is None:
-            return None, None, None
-        else:
-            index = index.long().to(weight_mask.device)
-            weight_cmask = CoarseMask(num_dim=4)
-            weight_cmask.add_index_mask(dim=dim, index=index)
-            bias_cmask = None
-            if dim == 0 and 'bias' in mask and mask['bias'] is not None:
-                bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
-                assert torch.all(torch.eq(index, bias_index)), \
-                    "bias mask should be consistent with weight mask"
-                bias_cmask = CoarseMask(num_dim=1)
-                bias_cmask.add_index_mask(dim=0, index=bias_index)
-            return index, weight_cmask, bias_cmask
+        index = index.long().to(weight_mask.device)
+        weight_cmask = CoarseMask(num_dim=4)
+        weight_cmask.add_index_mask(dim=dim, index=index)
+        bias_cmask = None
+        if dim == 0 and 'bias' in mask and mask['bias'] is not None:
+            bias_index = torch.nonzero(mask['bias'], as_tuple=True)[0]
+            assert torch.all(torch.eq(index, bias_index)), \
+                "bias mask should be consistent with weight mask"
+            bias_cmask = CoarseMask(num_dim=1)
+            bias_cmask.add_index_mask(dim=0, index=bias_index)
+        return index, weight_cmask, bias_cmask
     index, weight_cmask, bias_cmask = convert_to_coarse_mask(
         mask, dim=conv_prune_dim)
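
The change above deletes the early exit that returned (None, None, None) whenever no channel was actually pruned (a "full" mask); the coarse mask is now built unconditionally. A standalone sketch of how the channel index is derived (toy shapes, plain PyTorch):

import torch

# Toy 4-D conv weight mask: 4 output channels, channel 2 fully zeroed out.
weight_mask = torch.ones(4, 3, 3, 3)
weight_mask[2] = 0

dim = 0  # pruning output channels
sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
# indices of channels that still carry nonzero weights -> tensor([0, 1, 3])
index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=True)[0]
print(index)  # with no channel pruned, len(index) == weight_mask.shape[dim]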
@@ -962,6 +957,7 @@ def conv2d_inshape(module_masks, mask):
         # the same conv layer may be accessed more
         # than once, such as a concat operation.
         # mask conflict should be solved by fix_mask_conflict before speedup
+        assert module_masks.input_mask == mask
     # shape changes pass through depthwise conv layers
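
The new assertion covers the case the comment describes: a conv behind a concat can be visited once per branch during shape/mask propagation, and every visit must carry the same input mask. A hypothetical toy module with that topology:

import torch
import torch.nn as nn

# Hypothetical topology: conv2 sits behind a concat, so mask propagation
# can reach it once per branch; every visit must deliver the same input mask.
class ConcatNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1a = nn.Conv2d(3, 8, 3, padding=1)
        self.conv1b = nn.Conv2d(3, 8, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 4, 3, padding=1)

    def forward(self, x):
        y = torch.cat([self.conv1a(x), self.conv1b(x)], dim=1)
        return self.conv2(y)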
...
@@ -31,6 +31,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
         # if the input is the path of the mask_file
         assert os.path.exists(masks)
         masks = torch.load(masks)
+    assert len(masks) > 0, 'Mask tensor cannot be empty'
     # if the user uses the model and dummy_input to trace the model, we
     # should get the traced model beforehand, so that we only trace the
     # model once; GroupMaskConflict and ChannelMaskConflict will reuse
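
For orientation, a usage sketch of fix_mask_conflict with the signature shown above; the import path is assumed from NNI ~2.x and may differ between releases, and mask.pth is a placeholder file name:

import torch
import torchvision.models as models
# import path assumed for NNI ~2.x releases:
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict

model = models.resnet18()
dummy_input = torch.randn(1, 3, 224, 224)
# `masks` may be a {layer: {'weight': tensor, ...}} dict or a path saved via
# torch.save; after this commit an empty mask dict fails fast on the assert.
fixed_masks = fix_mask_conflict('./mask.pth', model=model, dummy_input=dummy_input)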
@@ -127,6 +128,7 @@ class CatMaskPadding(MaskFix):
         for layer in layers:
             if layer in self.masks:
                 continue
             module = name_to_module[layer]
             w_shape = module.weight.data.size()
             w_mask = torch.ones(w_shape).to(device)
@@ -136,6 +138,7 @@ class CatMaskPadding(MaskFix):
                 b_shape = module.bias.data.size()
                 b_mask = torch.ones(b_shape).to(device)
             self.masks[layer] = {'weight': w_mask, 'bias': b_mask}
+        return self.masks
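
The hunk's visible addition is the return value. In isolation, the padding logic (a hypothetical re-implementation, not NNI's actual class) looks like:

import torch

# Every layer feeding the same torch.cat must carry a mask, so unmasked
# layers get an all-ones (keep everything) mask; the padded dict is returned.
def pad_cat_masks(masks, name_to_module, layers, device='cpu'):
    for layer in layers:
        if layer in masks:
            continue
        module = name_to_module[layer]
        entry = {'weight': torch.ones(module.weight.data.size()).to(device)}
        if getattr(module, 'bias', None) is not None:
            entry['bias'] = torch.ones(module.bias.data.size()).to(device)
        masks[layer] = entry
    return masks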
@@ -250,6 +253,10 @@ class ChannelMaskConflict(MaskFix):
             self.model, self.dummy_input, self.traced)
         depen_sets = channel_depen.dependency_sets
         sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)
+        (_tmp_name, _tmp_tensor) = list(self.masks.items())[0]
+        device = _tmp_tensor['weight'].device
         for dset in depen_sets:
             if len(dset) <= 1:
                 continue
@@ -301,7 +308,7 @@ class ChannelMaskConflict(MaskFix):
             for i, dim_mask in enumerate(channel_masks):
                 if dim_mask is None:
-                    channel_masks[i] = torch.ones(num_channels).int()
+                    channel_masks[i] = torch.ones(num_channels).int().to(device)
             # merge masks with 'or'
             merged_channel_mask = channel_masks[0].clone()
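
The .to(device) fix pairs with the device lookup added in the previous hunk: a freshly created ones mask lives on the CPU, and merging it with CUDA-resident masks would raise a device-mismatch error. A minimal repro/fix sketch (requires CUDA):

import torch

if torch.cuda.is_available():
    device = torch.device('cuda')
    gpu_mask = torch.tensor([1, 0, 1, 0], dtype=torch.int32, device=device)
    cpu_mask = torch.ones(4).int()            # lives on the CPU by default
    # gpu_mask | cpu_mask                     # RuntimeError: devices differ
    merged = gpu_mask | cpu_mask.to(device)   # fix: move onto the masks' device
    print(merged)                             # tensor([1, 1, 1, 1], ...)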
...
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.
 import os
 import psutil
+import sys
 import numpy as np
 import torch
@@ -128,6 +129,18 @@ def generate_random_sparsity(model):
                                 'sparsity': sparsity})
     return cfg_list

+def generate_random_sparsity_v2(model):
+    """
+    Only select 50% of the layers to prune.
+    """
+    cfg_list = []
+    for name, module in model.named_modules():
+        if isinstance(module, nn.Conv2d):
+            if np.random.uniform(0, 1.0) > 0.5:
+                sparsity = np.random.uniform(0.5, 0.99)
+                cfg_list.append({'op_types': ['Conv2d'], 'op_names': [name],
+                                 'sparsity': sparsity})
+    return cfg_list

 def zero_bn_bias(model):
     with torch.no_grad():
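
Assuming the helpers above are importable, a quick sketch of the config format both generators emit; seeding makes the random layer selection reproducible in a test run:

import numpy as np
import torchvision.models as models

# Both generators emit NNI-style config lists such as:
# [{'op_types': ['Conv2d'], 'op_names': ['layer1.0.conv1'], 'sparsity': 0.73}]
net = models.resnet18()
np.random.seed(0)                        # reproducible layer selection
cfgs = generate_random_sparsity_v2(net)  # roughly half of the Conv2d layers
print(len(cfgs), cfgs[:1])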
@@ -292,52 +305,62 @@ class SpeedupTestCase(TestCase):
     # Example: https://msrasrg.visualstudio.com/NNIOpenSource/_build/results?buildId=16282
     def test_speedup_integration(self):
-        for model_name in ['resnet18', 'squeezenet1_1',
-                           'mobilenet_v2', 'densenet121',
-                           # 'inception_v3' inception is too large and may fail the pipeline
-                           'densenet169', 'resnet50']:
-            kwargs = {
-                'pretrained': True
-            }
-            if model_name == 'resnet50':
-                # testing multiple groups
-                kwargs = {
-                    'pretrained': False,
-                    'groups': 4
-                }
-            Model = getattr(models, model_name)
-            net = Model(**kwargs).to(device)
-            speedup_model = Model(**kwargs).to(device)
-            net.eval()  # this line is necessary
-            speedup_model.eval()
-            # randomly generate the prune config for the pruner
-            cfgs = generate_random_sparsity(net)
-            pruner = L1FilterPruner(net, cfgs)
-            pruner.compress()
-            pruner.export_model(MODEL_FILE, MASK_FILE)
-            pruner._unwrap_model()
-            state_dict = torch.load(MODEL_FILE)
-            speedup_model.load_state_dict(state_dict)
-            zero_bn_bias(net)
-            zero_bn_bias(speedup_model)
-            data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
-            ms = ModelSpeedup(speedup_model, data, MASK_FILE)
-            ms.speedup_model()
-            speedup_model.eval()
-            ori_out = net(data)
-            speeded_out = speedup_model(data)
-            ori_sum = torch.sum(ori_out).item()
-            speeded_sum = torch.sum(speeded_out).item()
-            print('Sum of the output of %s (before speedup):' %
-                  model_name, ori_sum)
-            print('Sum of the output of %s (after speedup):' %
-                  model_name, speeded_sum)
-            assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
-                (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
+        # skip this test on windows (7GB mem available) due to the memory limit
+        # Note: hack trick, may be updated in the future
+        if 'win' in sys.platform or 'Win' in sys.platform:
+            print('Skip test_speedup_integration on windows due to memory limit!')
+            return
+        Gen_cfg_funcs = [generate_random_sparsity, generate_random_sparsity_v2]
+        for model_name in ['resnet18', 'mobilenet_v2', 'squeezenet1_1', 'densenet121', 'densenet169',
+                           # 'inception_v3' inception is too large and may fail the pipeline
+                           'resnet50']:
+            for gen_cfg_func in Gen_cfg_funcs:
+                kwargs = {
+                    'pretrained': True
+                }
+                if model_name == 'resnet50':
+                    # testing multiple groups
+                    kwargs = {
+                        'pretrained': False,
+                        'groups': 4
+                    }
+                Model = getattr(models, model_name)
+                net = Model(**kwargs).to(device)
+                speedup_model = Model(**kwargs).to(device)
+                net.eval()  # this line is necessary
+                speedup_model.eval()
+                # randomly generate the prune config for the pruner
+                cfgs = gen_cfg_func(net)
+                print("Testing {} with compression config \n {}".format(model_name, cfgs))
+                pruner = L1FilterPruner(net, cfgs)
+                pruner.compress()
+                pruner.export_model(MODEL_FILE, MASK_FILE)
+                pruner._unwrap_model()
+                state_dict = torch.load(MODEL_FILE)
+                speedup_model.load_state_dict(state_dict)
+                zero_bn_bias(net)
+                zero_bn_bias(speedup_model)
+                data = torch.ones(BATCH_SIZE, 3, 128, 128).to(device)
+                ms = ModelSpeedup(speedup_model, data, MASK_FILE)
+                ms.speedup_model()
+                speedup_model.eval()
+                ori_out = net(data)
+                speeded_out = speedup_model(data)
+                ori_sum = torch.sum(ori_out).item()
+                speeded_sum = torch.sum(speeded_out).item()
+                print('Sum of the output of %s (before speedup):' %
+                      model_name, ori_sum)
+                print('Sum of the output of %s (after speedup):' %
+                      model_name, speeded_sum)
+                assert (abs(ori_sum - speeded_sum) / abs(ori_sum) < RELATIVE_THRESHOLD) or \
+                    (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)
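
Distilled from the test body, the prune, export, reload, speed-up round trip looks roughly like this standalone sketch; the import paths are assumed for NNI ~2.x and may have moved between releases, and the file names are illustrative:

import torch
import torchvision.models as models
# import paths assumed for NNI ~2.x:
from nni.compression.pytorch import ModelSpeedup
from nni.algorithms.compression.pytorch.pruning import L1FilterPruner

net = models.resnet18(pretrained=True).eval()
pruner = L1FilterPruner(net, [{'op_types': ['Conv2d'], 'sparsity': 0.5}])
pruner.compress()
pruner.export_model('model.pth', 'mask.pth')  # pruned weights + binary masks
pruner._unwrap_model()                        # restore the wrapped modules

speedup_net = models.resnet18(pretrained=True).eval()
speedup_net.load_state_dict(torch.load('model.pth'))
dummy = torch.ones(2, 3, 128, 128)
ModelSpeedup(speedup_net, dummy, 'mask.pth').speedup_model()  # shrink modules in place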
     def test_channel_prune(self):
         orig_net = resnet18(num_classes=10).to(device)
@@ -369,8 +392,10 @@ class SpeedupTestCase(TestCase):
             (abs(ori_sum - speeded_sum) < ABSOLUTE_THRESHOLD)

     def tearDown(self):
-        os.remove(MODEL_FILE)
-        os.remove(MASK_FILE)
+        if os.path.exists(MODEL_FILE):
+            os.remove(MODEL_FILE)
+        if os.path.exists(MASK_FILE):
+            os.remove(MASK_FILE)

 if __name__ == '__main__':
...