Unverified Commit f1b8cd2c authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

Add the support for aten::mul operator. (#2905)

parent 986d58c1
...@@ -15,6 +15,7 @@ replace_module = { ...@@ -15,6 +15,7 @@ replace_module = {
'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask), 'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
'ReLU': lambda module, mask: no_replace(module, mask), 'ReLU': lambda module, mask: no_replace(module, mask),
'ReLU6': lambda module, mask: no_replace(module, mask), 'ReLU6': lambda module, mask: no_replace(module, mask),
'Sigmoid': lambda module, mask: no_replace(module, mask),
'Linear': lambda module, mask: replace_linear(module, mask), 'Linear': lambda module, mask: replace_linear(module, mask),
'Dropout': lambda module, mask: no_replace(module, mask), 'Dropout': lambda module, mask: no_replace(module, mask),
'Dropout2d': lambda module, mask: no_replace(module, mask), 'Dropout2d': lambda module, mask: no_replace(module, mask),
......
...@@ -221,12 +221,14 @@ Infer output and weight shape of a module/function from its input shape ...@@ -221,12 +221,14 @@ Infer output and weight shape of a module/function from its input shape
infer_from_inshape = { infer_from_inshape = {
'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask), 'ReLU': lambda module_masks, mask: relu_inshape(module_masks, mask),
'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask), 'ReLU6': lambda module_masks, mask: relu_inshape(module_masks, mask),
'Sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask), 'aten::relu': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::tanh': lambda module_masks, mask: relu_inshape(module_masks, mask), 'aten::tanh': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::tanh_': lambda module_masks, mask: relu_inshape(module_masks, mask), 'aten::tanh_': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::hardtanh': lambda module_masks, mask: relu_inshape(module_masks, mask), 'aten::hardtanh': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::hardtanh_': lambda module_masks, mask: relu_inshape(module_masks, mask), 'aten::hardtanh_': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask), 'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask),
'aten::sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask),
'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask), 'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), 'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask), 'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
...@@ -243,6 +245,10 @@ infer_from_inshape = { ...@@ -243,6 +245,10 @@ infer_from_inshape = {
'BatchNorm2d': lambda module_masks, mask: batchnorm2d_inshape(module_masks, mask), 'BatchNorm2d': lambda module_masks, mask: batchnorm2d_inshape(module_masks, mask),
'aten::add_': lambda module_masks, mask: add_inshape(module_masks, mask), 'aten::add_': lambda module_masks, mask: add_inshape(module_masks, mask),
'aten::add': lambda module_mask, mask: add_inshape(module_mask, mask), 'aten::add': lambda module_mask, mask: add_inshape(module_mask, mask),
# mul has the similar behaviour with add, they both request
# the input tesors to have the same shape
'aten::mul': lambda module_mask, mask: add_inshape(module_mask, mask),
'aten::mul_': lambda module_mask, mask: add_inshape(module_mask, mask),
'aten::cat': lambda module_mask, mask, cat_info, last_visited: cat_inshape(module_mask, mask, cat_info, last_visited), 'aten::cat': lambda module_mask, mask, cat_info, last_visited: cat_inshape(module_mask, mask, cat_info, last_visited),
'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape), 'aten::mean': lambda module_masks, mask, shape: mean_inshape(module_masks, mask, shape),
'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask), 'Dropout': lambda module_masks, mask: dropout_inshape(module_masks, mask),
......
...@@ -284,7 +284,7 @@ class ChannelMaskConflict(MaskFix): ...@@ -284,7 +284,7 @@ class ChannelMaskConflict(MaskFix):
ori_channels = w_shape[0] ori_channels = w_shape[0]
for i in channel_remain: for i in channel_remain:
mask['weight'][i] = torch.ones(w_shape[1:]) mask['weight'][i] = torch.ones(w_shape[1:])
if hasattr(mask, 'bias'): if 'bias' in mask and mask['bias'] is not None:
mask['bias'][i] = 1 mask['bias'][i] = 1
_logger.info(','.join(dset)) _logger.info(','.join(dset))
_logger.info('Pruned Filters after fixing conflict:') _logger.info('Pruned Filters after fixing conflict:')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment