Unverified Commit b8c0fb6e authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

compression speedup: add init file (#2063)

parent b4ab371b
......@@ -2,7 +2,7 @@
# Licensed under the MIT license.
import torch
from .infer_shape import CoarseMask, ModuleMasks
from .infer_shape import ModuleMasks
replace_module = {
'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
......
......@@ -379,7 +379,7 @@ class ModelSpeedup:
def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
"""
Infer input shape / output shape based on the module's weight mask / input shape / output shape.
For a module:
Infer its input and output shape from its weight mask
Infer its output shape from its input shape
......
......@@ -56,7 +56,7 @@ class CoarseMask:
s.add(num)
for num in index_b:
s.add(num)
return torch.tensor(sorted(s))
return torch.tensor(sorted(s)) # pylint: disable=not-callable
def merge(self, cmask):
"""
......@@ -98,7 +98,7 @@ class ModuleMasks:
self.param_masks = dict()
self.input_mask = None
self.output_mask = None
def set_param_masks(self, name, mask):
"""
Parameters
......@@ -217,7 +217,7 @@ def view_inshape(module_masks, mask, shape):
TODO: consider replace tensor.view with nn.Flatten, because tensor.view is not
included in module, thus, cannot be replaced by our framework.
Parameters
----------
module_masks : ModuleMasks
......@@ -250,7 +250,7 @@ def view_inshape(module_masks, mask, shape):
step_size = shape['in_shape'][2] * shape['in_shape'][3]
for loc in mask.mask_index[1]:
index.extend([loc * step_size + i for i in range(step_size)])
output_cmask.add_index_mask(dim=1, index=torch.tensor(index))
output_cmask.add_index_mask(dim=1, index=torch.tensor(index)) # pylint: disable=not-callable
module_masks.set_output_mask(output_cmask)
return output_cmask
......@@ -373,7 +373,6 @@ def conv2d_mask(module_masks, mask):
"""
assert 'weight' in mask
assert isinstance(mask['weight'], torch.Tensor)
cmask = None
weight_mask = mask['weight']
shape = weight_mask.size()
ones = torch.ones(shape[1:]).to(weight_mask.device)
......@@ -451,7 +450,7 @@ def conv2d_outshape(module_masks, mask):
The ModuleMasks instance of the conv2d
mask : CoarseMask
The mask of its output tensor
Returns
-------
CoarseMask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment