"tests/git@developer.sourcefind.cn:OpenDAS/apex.git" did not exist on "03421e87c0aa212384abddd0a1aef44dafa7e8c3"
Unverified Commit beeea328 authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

[bug bash] issue 2706 (#2818)



* bug bash

* fix one more bug.
Signed-off-by: default avatarNingxin <Ningxin.Zheng@microsoft.com>
parent 625a72d5
...@@ -286,7 +286,7 @@ def cat_inshape(module_masks, mask, cat_info, last_visited): ...@@ -286,7 +286,7 @@ def cat_inshape(module_masks, mask, cat_info, last_visited):
Parameters Parameters
---------- ----------
module_masks : ModuleMasks module_masks : ModuleMasks
The ModuleMasks instance of the batchnorm2d The ModuleMasks instance of the Conv2d
mask : CoarseMask mask : CoarseMask
The mask of its input tensor The mask of its input tensor
cat_info: dict cat_info: dict
......
...@@ -118,11 +118,14 @@ class CatMaskPadding(MaskFix): ...@@ -118,11 +118,14 @@ class CatMaskPadding(MaskFix):
continue continue
# pad the mask for the non-pruned layers # pad the mask for the non-pruned layers
for layer in layers: for layer in layers:
if layer in self.masks:
continue
module = name_to_module[layer] module = name_to_module[layer]
w_shape = module.weight.data.size() w_shape = module.weight.data.size()
w_mask = torch.ones(w_shape).to(device) w_mask = torch.ones(w_shape).to(device)
b_mask = None b_mask = None
if hasattr(module, 'bias'): if hasattr(module, 'bias') and module.bias is not None:
# module.bias may be None
b_shape = module.bias.data.size() b_shape = module.bias.data.size()
b_mask = torch.ones(b_shape).to(device) b_mask = torch.ones(b_shape).to(device)
self.masks[layer] = {'weight':w_mask, 'bias':b_mask} self.masks[layer] = {'weight':w_mask, 'bias':b_mask}
......
...@@ -145,7 +145,7 @@ class SpeedupTestCase(TestCase): ...@@ -145,7 +145,7 @@ class SpeedupTestCase(TestCase):
assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY) assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)
def test_speedup_integration(self): def test_speedup_integration(self):
for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'inception_v3']: for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'densenet169', 'inception_v3']:
Model = getattr(models, model_name) Model = getattr(models, model_name)
net = Model(pretrained=True, progress=False).to(device) net = Model(pretrained=True, progress=False).to(device)
speedup_model = Model().to(device) speedup_model = Model().to(device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment