Unverified Commit b8c0fb6e authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

compression speedup: add init file (#2063)

parent b4ab371b
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import torch import torch
from .infer_shape import CoarseMask, ModuleMasks from .infer_shape import ModuleMasks
replace_module = { replace_module = {
'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask), 'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
......
...@@ -379,7 +379,7 @@ class ModelSpeedup: ...@@ -379,7 +379,7 @@ class ModelSpeedup:
def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None): def infer_module_mask(self, module_name, mask=None, in_shape=None, out_shape=None):
""" """
Infer input shape / output shape based on the module's weight mask / input shape / output shape. Infer input shape / output shape based on the module's weight mask / input shape / output shape.
For a module: For a module:
Infer its input and output shape from its weight mask Infer its input and output shape from its weight mask
Infer its output shape from its input shape Infer its output shape from its input shape
......
...@@ -56,7 +56,7 @@ class CoarseMask: ...@@ -56,7 +56,7 @@ class CoarseMask:
s.add(num) s.add(num)
for num in index_b: for num in index_b:
s.add(num) s.add(num)
return torch.tensor(sorted(s)) return torch.tensor(sorted(s)) # pylint: disable=not-callable
def merge(self, cmask): def merge(self, cmask):
""" """
...@@ -98,7 +98,7 @@ class ModuleMasks: ...@@ -98,7 +98,7 @@ class ModuleMasks:
self.param_masks = dict() self.param_masks = dict()
self.input_mask = None self.input_mask = None
self.output_mask = None self.output_mask = None
def set_param_masks(self, name, mask): def set_param_masks(self, name, mask):
""" """
Parameters Parameters
...@@ -217,7 +217,7 @@ def view_inshape(module_masks, mask, shape): ...@@ -217,7 +217,7 @@ def view_inshape(module_masks, mask, shape):
TODO: consider replace tensor.view with nn.Flatten, because tensor.view is not TODO: consider replace tensor.view with nn.Flatten, because tensor.view is not
included in module, thus, cannot be replaced by our framework. included in module, thus, cannot be replaced by our framework.
Parameters Parameters
---------- ----------
module_masks : ModuleMasks module_masks : ModuleMasks
...@@ -250,7 +250,7 @@ def view_inshape(module_masks, mask, shape): ...@@ -250,7 +250,7 @@ def view_inshape(module_masks, mask, shape):
step_size = shape['in_shape'][2] * shape['in_shape'][3] step_size = shape['in_shape'][2] * shape['in_shape'][3]
for loc in mask.mask_index[1]: for loc in mask.mask_index[1]:
index.extend([loc * step_size + i for i in range(step_size)]) index.extend([loc * step_size + i for i in range(step_size)])
output_cmask.add_index_mask(dim=1, index=torch.tensor(index)) output_cmask.add_index_mask(dim=1, index=torch.tensor(index)) # pylint: disable=not-callable
module_masks.set_output_mask(output_cmask) module_masks.set_output_mask(output_cmask)
return output_cmask return output_cmask
...@@ -373,7 +373,6 @@ def conv2d_mask(module_masks, mask): ...@@ -373,7 +373,6 @@ def conv2d_mask(module_masks, mask):
""" """
assert 'weight' in mask assert 'weight' in mask
assert isinstance(mask['weight'], torch.Tensor) assert isinstance(mask['weight'], torch.Tensor)
cmask = None
weight_mask = mask['weight'] weight_mask = mask['weight']
shape = weight_mask.size() shape = weight_mask.size()
ones = torch.ones(shape[1:]).to(weight_mask.device) ones = torch.ones(shape[1:]).to(weight_mask.device)
...@@ -451,7 +450,7 @@ def conv2d_outshape(module_masks, mask): ...@@ -451,7 +450,7 @@ def conv2d_outshape(module_masks, mask):
The ModuleMasks instance of the conv2d The ModuleMasks instance of the conv2d
mask : CoarseMask mask : CoarseMask
The mask of its output tensor The mask of its output tensor
Returns Returns
------- -------
CoarseMask CoarseMask
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment