Unverified commit b8c0fb6e, authored by QuanluZhang and committed by GitHub

compression speedup: add init file (#2063)

parent b4ab371b
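Note: the new __init__.py referenced by the commit title is not part of the hunks shown below; the visible changes are the import and lint cleanups that accompany it. Purely as an illustration (not the actual file contents), a package init for a module like this is often empty or simply re-exports the public names that other files import:

# Hypothetical sketch only -- the real __init__.py added by commit b8c0fb6e is not shown here.
# An empty file already makes the directory importable as a package; a re-export such as the
# one below (names taken from the imports visible in this diff) is an optional convenience.
from .infer_shape import CoarseMask, ModuleMasks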
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.
 import torch
-from .infer_shape import CoarseMask, ModuleMasks
+from .infer_shape import ModuleMasks
 replace_module = {
     'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
@@ -56,7 +56,7 @@ class CoarseMask:
             s.add(num)
         for num in index_b:
             s.add(num)
-        return torch.tensor(sorted(s))
+        return torch.tensor(sorted(s)) # pylint: disable=not-callable

    def merge(self, cmask):
        """
@@ -250,7 +250,7 @@ def view_inshape(module_masks, mask, shape):
     step_size = shape['in_shape'][2] * shape['in_shape'][3]
     for loc in mask.mask_index[1]:
         index.extend([loc * step_size + i for i in range(step_size)])
-    output_cmask.add_index_mask(dim=1, index=torch.tensor(index))
+    output_cmask.add_index_mask(dim=1, index=torch.tensor(index)) # pylint: disable=not-callable
     module_masks.set_output_mask(output_cmask)
     return output_cmask
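The index arithmetic in the view_inshape hunk above maps surviving channel indices to positions in the flattened tensor: after a view from (N, C, H, W) to (N, C*H*W), channel loc occupies the step_size = H*W consecutive positions starting at loc * step_size. A minimal standalone sketch with made-up values (H, W and kept_channels are illustrative, not from the diff):

# Illustration of the channel-to-flattened-index expansion used in view_inshape.
H, W = 2, 2                      # spatial size, so step_size = H * W = 4
kept_channels = [0, 2]           # channels that survive pruning (mask.mask_index[1])
step_size = H * W

index = []
for loc in kept_channels:
    index.extend([loc * step_size + i for i in range(step_size)])

print(index)                     # [0, 1, 2, 3, 8, 9, 10, 11]
# Channel 0 occupies flattened positions 0..3 and channel 2 occupies 8..11;
# this is the index mask passed to output_cmask.add_index_mask(dim=1, ...).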
@@ -373,7 +373,6 @@ def conv2d_mask(module_masks, mask):
     """
     assert 'weight' in mask
     assert isinstance(mask['weight'], torch.Tensor)
-    cmask = None
     weight_mask = mask['weight']
     shape = weight_mask.size()
     ones = torch.ones(shape[1:]).to(weight_mask.device)
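For context on the pylint comments added in two of the hunks above: pylint cannot introspect members defined in torch's C extension, so some versions report torch.tensor(...) as not-callable (E1102) even though the call is valid; the inline disables silence only that check on the affected lines. A minimal illustration of the pattern, not taken from this commit:

import torch

# torch.tensor is a valid call, but because it lives in the C extension,
# some pylint versions raise E1102 (not-callable) here; the trailing comment
# suppresses that single check on this line only.
values = torch.tensor([1, 2, 3])  # pylint: disable=not-callable
print(values.sum().item())  # prints 6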