Unverified Commit 4b1f46a3 authored by Ming-Hsuan-Tu, committed by GitHub

support prelu to speedup model (#3842)

parent c717ce57
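For context, below is a minimal sketch (not part of this commit) of how the new PReLU handling would be exercised end to end. It assumes the NNI v2.x layout where ModelSpeedup is imported from nni.compression.pytorch; the Net module and the mask.pth file are hypothetical placeholders for a pruned model and its exported masks.

```python
# Sketch only, not from this commit. Assumes the NNI v2.x API where
# ModelSpeedup lives in nni.compression.pytorch; Net and 'mask.pth' are
# hypothetical stand-ins for a pruned model and its exported mask file.
import torch
import torch.nn as nn
from nni.compression.pytorch import ModelSpeedup

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, 3, padding=1)
        self.act = nn.PReLU(16)   # one slope per channel

    def forward(self, x):
        return self.act(self.conv(x))

model = Net()
dummy_input = torch.rand(1, 3, 32, 32)
# 'mask.pth' is assumed to have been produced beforehand by a pruner
ModelSpeedup(model, dummy_input, 'mask.pth').speedup_model()
print(model)  # the PReLU is rebuilt with the reduced number of channels
```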
@@ -15,6 +15,7 @@ replace_module = {
    'AvgPool2d': lambda module, mask: no_replace(module, mask),
    'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
    'ReLU': lambda module, mask: no_replace(module, mask),
    'PReLU': lambda module, mask: replace_prelu(module, mask),
    'ReLU6': lambda module, mask: no_replace(module, mask),
    'Sigmoid': lambda module, mask: no_replace(module, mask),
    'Linear': lambda module, mask: replace_linear(module, mask),
@@ -31,6 +32,31 @@ def no_replace(module, mask):
    _logger.debug("no need to replace")
    return module
def replace_prelu(prelu, mask):
    """
    Parameters
    ----------
    prelu : torch.nn.PReLU
        The prelu module to be replaced
    mask : ModuleMasks
        The masks of this module

    Returns
    -------
    torch.nn.PReLU
        The new prelu module
    """
    assert isinstance(mask, ModuleMasks)
    assert 'weight' in mask.param_masks
    index = mask.param_masks['weight'].mask_index[0]
    num_features = index.size()[0]
    # _logger.debug("replace prelu with num_features: %d", num_features)
    if num_features == 0:
        return torch.nn.Identity()
    new_prelu = torch.nn.PReLU(num_features)
    # keep only the per-channel slopes of the channels that survive pruning
    new_prelu.weight.data = torch.index_select(prelu.weight.data, 0, index)
    return new_prelu
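The core of replace_prelu is a per-channel index_select on the slope parameter. Here is a standalone illustration, outside the diff, using a made-up index of surviving channels:

```python
# Illustration only, not from the diff: keep the PReLU slopes of the channels
# that survive pruning. The index values are made up for the example.
import torch

old_prelu = torch.nn.PReLU(8)
index = torch.tensor([0, 2, 5])                # channels kept after pruning
new_prelu = torch.nn.PReLU(index.size(0))
new_prelu.weight.data = torch.index_select(old_prelu.weight.data, 0, index)
print(new_prelu.weight.shape)                  # torch.Size([3])
```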
def replace_linear(linear, mask):
    """
    ...
@@ -240,6 +240,7 @@ Infer output and weight shape of a module/function from its input shape
infer_from_inshape = {
    'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
    'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask),
    'PReLU': lambda module_masks, mask: prelu_inshape(module_masks, mask),
    'Sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask),
    'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
    'aten::tanh': lambda module_masks, mask: relu_inshape(module_masks, mask),
@@ -293,6 +294,7 @@ infer_from_outshape = {
    'AdaptiveAvgPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
    'ReLU': lambda module_masks, mask: relu_outshape(module_masks, mask),
    'PReLU': lambda module_masks, mask: prelu_outshape(module_masks, mask),
    'ReLU6': lambda module_masks, mask: relu_outshape(module_masks, mask),
    'aten::relu': lambda module_masks, mask: relu_outshape(module_masks, mask),
    'aten::tanh': lambda module_masks, mask: relu_outshape(module_masks, mask),
@@ -735,6 +737,62 @@ def maxpool2d_outshape(module_masks, mask):
    module_masks.set_output_mask(mask)
    return mask
def prelu_inshape(module_masks, mask):
    """
    We assume only the second dimension has a coarse-grained mask

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the PReLU
    mask : CoarseMask
        The mask of its input tensor

    Returns
    -------
    CoarseMask
        The mask of its output tensor
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None
    module_masks.set_input_mask(mask)
    module_masks.set_output_mask(mask)
    # PReLU keeps one slope per channel, so its weight mask is one-dimensional
    weight_cmask = CoarseMask(num_dim=1)
    weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    module_masks.set_param_masks('weight', weight_cmask)
    return mask
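To make the inference rule concrete, here is a conceptual sketch using plain tensors rather than the CoarseMask/ModuleMasks machinery: a channel mask on dimension 1 of the input passes through PReLU unchanged and also selects the surviving entries of its 1-D weight. The kept_channels tensor below is invented for the example.

```python
# Conceptual sketch with plain tensors; kept_channels stands in for
# mask.mask_index[1] and is invented for the example.
import torch

kept_channels = torch.tensor([1, 3])
x = torch.rand(1, 4, 8, 8)                     # N, C, H, W
prelu = torch.nn.PReLU(4)

pruned_x = torch.index_select(x, 1, kept_channels)                   # same mask in and out
pruned_w = torch.index_select(prelu.weight.data, 0, kept_channels)   # 1-D weight mask
print(pruned_x.shape, pruned_w.shape)          # torch.Size([1, 2, 8, 8]) torch.Size([2])
```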
def prelu_outshape(module_masks, mask):
    """
    We assume only the second dimension has a coarse-grained mask

    Parameters
    ----------
    module_masks : ModuleMasks
        The ModuleMasks instance of the PReLU
    mask : CoarseMask
        The mask of its output tensor

    Returns
    -------
    CoarseMask
        The mask of its input tensor
    """
    assert isinstance(mask, CoarseMask)
    assert mask.mask_index[1] is not None
    assert mask.mask_index[0] is None
    assert mask.mask_index[2] is None
    assert mask.mask_index[3] is None
    weight_cmask = CoarseMask(num_dim=4)
    weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
    module_masks.set_param_masks('weight', weight_cmask)
    return mask
def relu_inshape(module_masks, mask):
    """
    ...