Unverified Commit f102d5bd authored by Xuesong Wang, committed by GitHub

fix the bug in nni/compression/pytorch/speedup/infer_shape.py, line no. 625-626 (#3588)

parent 3725c8fc
@@ -597,15 +597,15 @@ def view_outshape(module_masks, mask, shape):
     Parameters
     ----------
     module_masks : ModuleMasks
-        The ModuleMasks instance of the ```flatten``` op
+        The ModuleMasks instance of the ```view``` op
     mask : CoarseMask
-        The mask of its input tensor
+        The mask of its output tensor
     shape : dict
         Original shape of its input and output tensors
     Returns
     -------
     CoarseMask
-        The mask of its output tensor
+        The mask of its input tensor
     """
     # NOTE: the case constrained by the following four asserts
     assert shape['in_shape'][0] == shape['out_shape'][0]
@@ -620,10 +620,11 @@ def view_outshape(module_masks, mask, shape):
     module_masks.set_output_mask(mask)
     input_cmask = CoarseMask(num_dim=4)
-    index = []
+    index = set()
     step_size = shape['in_shape'][2] * shape['in_shape'][3]
     for loc in mask.mask_index[1]:
-        index.extend([loc * step_size + i for i in range(step_size)])
+        index.add(loc // step_size)
+    index = sorted(list(index))
     input_cmask.add_index_mask(dim=1, index=torch.tensor(index).to(mask.mask_index[1].device)) # pylint: disable=not-callable
     module_masks.set_input_mask(input_cmask)
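
The fix inverts the direction of the index mapping. For a view/flatten that reshapes an [N, C, H, W] tensor into [N, C*H*W], each position loc in the flattened output belongs to input channel loc // (H * W); the old code instead expanded each loc into loc * step_size + i, i.e. it mapped input channels to output positions, which is the wrong direction when propagating the output mask back to the input. Because many flattened positions fall into the same channel, duplicates are collected in a set and sorted before the index mask is built. Below is a minimal standalone sketch of that mapping with made-up shapes and mask indices; it does not use NNI's CoarseMask/ModuleMasks classes.

import torch

# Hypothetical example: a view op flattens an [N, C, H, W] input into [N, C*H*W].
in_shape = (1, 4, 2, 3)                   # N=1, C=4, H=2, W=3 (made-up values)
step_size = in_shape[2] * in_shape[3]     # H * W = 6 flattened positions per input channel

# Suppose the output mask keeps these positions of the flattened dimension.
out_index = torch.tensor([0, 5, 13, 17])  # hypothetical kept indices

# Each flattened position loc lives in input channel loc // step_size;
# the set removes duplicates since several positions map to the same channel.
kept_channels = sorted({int(loc) // step_size for loc in out_index})
print(kept_channels)                      # [0, 2] -> input channels 0 and 2 are kept

With the old expansion, the same out_index would have produced positions such as 5 * 6 .. 5 * 6 + 5, which index far beyond the 4 input channels, hence the bug reported in #3588.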