Unverified Commit f102d5bd authored by Xuesong Wang's avatar Xuesong Wang Committed by GitHub
Browse files

fix the bug in nni/compression/pytorch/speedup/infer_shape.py, line no. 625-626 (#3588)

parent 3725c8fc
...@@ -597,15 +597,15 @@ def view_outshape(module_masks, mask, shape): ...@@ -597,15 +597,15 @@ def view_outshape(module_masks, mask, shape):
Parameters Parameters
---------- ----------
module_masks : ModuleMasks module_masks : ModuleMasks
The ModuleMasks instance of the ```flatten``` op The ModuleMasks instance of the ```view``` op
mask : CoarseMask mask : CoarseMask
The mask of its input tensor The mask of its output tensor
shape : dict shape : dict
Original shape of its input and output tensors Original shape of its input and output tensors
Returns Returns
------- -------
CoarseMask CoarseMask
The mask of its output tensor The mask of its input tensor
""" """
# NOTE: the case constrained by the following four asserts # NOTE: the case constrained by the following four asserts
assert shape['in_shape'][0] == shape['out_shape'][0] assert shape['in_shape'][0] == shape['out_shape'][0]
...@@ -620,10 +620,11 @@ def view_outshape(module_masks, mask, shape): ...@@ -620,10 +620,11 @@ def view_outshape(module_masks, mask, shape):
module_masks.set_output_mask(mask) module_masks.set_output_mask(mask)
input_cmask = CoarseMask(num_dim=4) input_cmask = CoarseMask(num_dim=4)
index = [] index = set()
step_size = shape['in_shape'][2] * shape['in_shape'][3] step_size = shape['in_shape'][2] * shape['in_shape'][3]
for loc in mask.mask_index[1]: for loc in mask.mask_index[1]:
index.extend([loc * step_size + i for i in range(step_size)]) index.add(loc // step_size)
index = sorted(list(index))
input_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) # pylint: disable=not-callable input_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) # pylint: disable=not-callable
module_masks.set_input_mask(input_cmask) module_masks.set_input_mask(input_cmask)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment