Unverified Commit b8c0fb6e authored by QuanluZhang's avatar QuanluZhang Committed by GitHub
Browse files

compression speedup: add init file (#2063)

parent b4ab371b
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Licensed under the MIT license. # Licensed under the MIT license.
import torch import torch
from .infer_shape import CoarseMask, ModuleMasks from .infer_shape import ModuleMasks
replace_module = { replace_module = {
'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask), 'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
......
...@@ -56,7 +56,7 @@ class CoarseMask: ...@@ -56,7 +56,7 @@ class CoarseMask:
s.add(num) s.add(num)
for num in index_b: for num in index_b:
s.add(num) s.add(num)
return torch.tensor(sorted(s)) return torch.tensor(sorted(s)) # pylint: disable=not-callable
def merge(self, cmask): def merge(self, cmask):
""" """
...@@ -250,7 +250,7 @@ def view_inshape(module_masks, mask, shape): ...@@ -250,7 +250,7 @@ def view_inshape(module_masks, mask, shape):
step_size = shape['in_shape'][2] * shape['in_shape'][3] step_size = shape['in_shape'][2] * shape['in_shape'][3]
for loc in mask.mask_index[1]: for loc in mask.mask_index[1]:
index.extend([loc * step_size + i for i in range(step_size)]) index.extend([loc * step_size + i for i in range(step_size)])
output_cmask.add_index_mask(dim=1, index=torch.tensor(index)) output_cmask.add_index_mask(dim=1, index=torch.tensor(index)) # pylint: disable=not-callable
module_masks.set_output_mask(output_cmask) module_masks.set_output_mask(output_cmask)
return output_cmask return output_cmask
...@@ -373,7 +373,6 @@ def conv2d_mask(module_masks, mask): ...@@ -373,7 +373,6 @@ def conv2d_mask(module_masks, mask):
""" """
assert 'weight' in mask assert 'weight' in mask
assert isinstance(mask['weight'], torch.Tensor) assert isinstance(mask['weight'], torch.Tensor)
cmask = None
weight_mask = mask['weight'] weight_mask = mask['weight']
shape = weight_mask.size() shape = weight_mask.size()
ones = torch.ones(shape[1:]).to(weight_mask.device) ones = torch.ones(shape[1:]).to(weight_mask.device)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment