Unverified commit 445bbfd2, authored by Ningxin Zheng, committed by GitHub

Speedup enhancement (#2719)

parent 41312de5
@@ -141,6 +141,14 @@ class ModelSpeedup:
         """
         for module_name, mask in self.masks.items():
             _logger.debug('Start mask inference from %s', module_name)
+            if module_name not in self.torch_graph.name_to_node:
+                # this module is not traced in the torch_graph;
+                # jit.trace only correctly records functions and
+                # modules which are not data dependent (e.g., do
+                # not have conditionals on data in tensors),
+                # so if a node is not traced, we just skip it
+                _logger.warning('%s has mask, but not found in the traced graph, just skip it.', module_name)
+                continue
             self.infer_module_mask(module_name, None, mask=mask)

     def replace_compressed_modules(self):
...
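For readers unfamiliar with the tracing limitation the new comment mentions, here is a minimal illustration (a toy module, not from the PR) of why a masked module may have no matching node in the traced graph:

import torch
import torch.nn as nn

class DataDependent(nn.Module):
    def forward(self, x):
        # the branch depends on tensor *values*, which tracing cannot record
        if x.sum() > 0:
            return torch.relu(x)
        return -x

# jit.trace warns and freezes whichever branch the example input takes;
# ops behind the untaken branch never appear in traced.graph, so a mask
# stored for such a module cannot be matched to any graph node
traced = torch.jit.trace(DataDependent(), torch.randn(2, 2))
print(traced.graph)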
@@ -222,6 +222,7 @@ infer_from_inshape = {
     'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
+    'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
     'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
     'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
@@ -241,7 +242,8 @@ infer_from_inshape = {
     'aten::cat': lambda module_mask, mask, cat_info, last_visited: cat_inshape(module_mask, mask, cat_info, last_visited),
     'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape),
     'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask),
-    'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask)
+    'Dropout2d': lambda module_masks, mask: dropout_inshape(module_masks, mask),
+    'aten::dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask)
 }
 """
@@ -258,8 +260,14 @@ def dropout_inshape(module_masks, mask):
         return module_masks.output_mask
     # if already visited
     assert module_masks.input_mask <= mask
-    if module_masks.input_mask == mask:
-        return None
+    # The masks are passed by reference (not by value), so `mask` and
+    # `module_masks.input_mask` may actually be two references to the
+    # same object. We therefore have to keep passing the mask on to the
+    # following nodes even when module_masks.input_mask == mask.
+    # If the mask were passed via copy.deepcopy(), we could stop when
+    # module_masks.input_mask == mask:
+    # if module_masks.input_mask == mask:
+    #     return None
     module_masks.set_input_mask(mask)
     module_masks.set_output_mask(mask)
     return module_masks.output_mask
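The comment above hinges on Python's reference semantics; a small self-contained demonstration, with plain lists standing in for CoarseMask objects:

import copy

mask = [1, 0, 1]                   # stands in for a CoarseMask
input_mask = mask                  # stored by reference: the same object
print(input_mask == mask)          # True
print(input_mask is mask)          # True as well

deep = copy.deepcopy(mask)         # an equal but distinct object
print(deep == mask, deep is mask)  # True False
# With references, `==` cannot distinguish "mask already propagated"
# from "the same object handed to us again", so propagation must continue.
# Only with deep copies would equality be a safe early-exit condition.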
@@ -413,7 +421,8 @@ def linear_inshape(module_masks, mask):
     """
     assert isinstance(mask, CoarseMask)
     assert mask.mask_index[0] is None
-    assert module_masks.input_mask is None
+    if module_masks.input_mask is not None:
+        assert module_masks.input_mask <= mask
     module_masks.set_input_mask(mask)
     return None
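The new <= checks rely on CoarseMask defining a partial order. As an assumption (the actual semantics live in NNI's CoarseMask, which is not shown in this diff), <= plausibly means "this mask keeps a subset of the channels the other mask keeps"; a toy model of that compatibility check:

class ToyMask:
    """Toy stand-in for CoarseMask; the __le__ semantics here are an assumption."""
    def __init__(self, kept):
        self.kept = frozenset(kept)   # channel indices kept after pruning

    def __le__(self, other):
        # "no conflict": everything this mask keeps, the other keeps too
        return self.kept <= other.kept

stored = ToyMask({0, 2})
incoming = ToyMask({0, 1, 2})
assert stored <= incoming               # compatible: a revisit is harmless
assert not (ToyMask({3}) <= incoming)   # a genuine mask conflict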
@@ -451,7 +460,10 @@ def view_inshape(module_masks, mask, shape):
     assert mask.mask_index[0] is None
     assert mask.mask_index[2] is None
     assert mask.mask_index[3] is None
-    assert module_masks.input_mask is None
+    # due to the cat operation, the same node may be
+    # accessed more than once
+    if module_masks.input_mask is not None:
+        assert module_masks.input_mask <= mask
     module_masks.set_input_mask(mask)
     output_cmask = CoarseMask(num_dim=2)
     index = []
@@ -535,12 +547,9 @@ def relu_inshape(module_masks, mask):
         The mask of its output tensor
     """
     assert isinstance(mask, CoarseMask)
-    # TODO: double check this assert, is it possible that a module is passed twice
     if module_masks.input_mask is not None:
         # check if there is a mask conflict
-        assert module_masks.input_mask == mask
-        # No need to pass the mask again
-        return None
+        assert module_masks.input_mask <= mask
     # assert module_masks.input_mask is None, "A relu op can only be processed once"
     module_masks.set_input_mask(mask)
     module_masks.set_output_mask(mask)
...
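Note the behavioral change: relu_inshape no longer returns early on a revisit, mirroring the dropout fix above, so the mask keeps flowing downstream each time a path (e.g., through a cat or a residual connection) reaches the same ReLU again. A self-contained sketch of the revisit-tolerant pattern, with plain sets standing in for CoarseMask:

def relu_inshape_toy(state, kept):
    # kept: set of kept channel indices, standing in for a CoarseMask
    if state.get('input') is not None:
        assert state['input'] <= kept   # conflict check only, no early return
    state['input'] = kept
    state['output'] = kept
    return kept                         # always hand the mask onward

state = {}
assert relu_inshape_toy(state, {0, 2}) == {0, 2}   # first visit
assert relu_inshape_toy(state, {0, 2}) == {0, 2}   # revisit: still propagates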
@@ -145,18 +145,18 @@ class SpeedupTestCase(TestCase):
         assert model.backbone2.fc1.in_features == int(orig_model.backbone2.fc1.in_features * SPARSITY)

     def test_speedup_integration(self):
-        for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2']:
+        for model_name in ['resnet18', 'squeezenet1_1', 'mobilenet_v2', 'densenet121', 'inception_v3']:
             Model = getattr(models, model_name)
             net = Model(pretrained=True, progress=False).to(device)
-            speedup_model = Model().to(device)
             net.eval() # this line is necessary
-            speedup_model.eval()
             # randomly generate the prune config for the pruner
             cfgs = generate_random_sparsity(net)
             pruner = L1FilterPruner(net, cfgs)
             pruner.compress()
             pruner.export_model(MODEL_FILE, MASK_FILE)
             pruner._unwrap_model()
+            speedup_model = Model().to(device)
+            speedup_model.eval()
             state_dict = torch.load(MODEL_FILE)
             speedup_model.load_state_dict(state_dict)
             zero_bn_bias(net)
...
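For readers following the test change: the fresh speedup_model is now constructed only after the pruner has been unwrapped from the original net, so the sequence is prune, export, unwrap, then build and load. A condensed sketch of that flow for one model; the import path is assumed from the NNI 1.x layout of this commit's era, the config and file names are placeholders:

import torch
import torchvision.models as models
from nni.compression.torch import L1FilterPruner, ModelSpeedup  # assumed 1.x layout

device = torch.device('cpu')
net = models.resnet18(pretrained=True, progress=False).to(device)
net.eval()  # this line is necessary

cfgs = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]  # placeholder config
pruner = L1FilterPruner(net, cfgs)
pruner.compress()
pruner.export_model('model.pth', 'mask.pth')
pruner._unwrap_model()

# build the copy to be sped up only now, then load the exported weights
speedup_model = models.resnet18().to(device)
speedup_model.eval()
speedup_model.load_state_dict(torch.load('model.pth'))

dummy_input = torch.randn(1, 3, 224, 224).to(device)
ModelSpeedup(speedup_model, dummy_input, 'mask.pth').speedup_model()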