Unverified Commit aea82f23 authored by Ningxin Zheng's avatar Ningxin Zheng Committed by GitHub
Browse files

Fix a small bug in the dependency_aware mode. (#3143)


Signed-off-by: default avatarNingxin <Ningxin.Zheng@microsoft.com>
parent 95f731e4
...@@ -278,7 +278,8 @@ class StructuredWeightMasker(WeightMasker): ...@@ -278,7 +278,8 @@ class StructuredWeightMasker(WeightMasker):
sparsity, _w, _w_idx) sparsity, _w, _w_idx)
num_total = current_weight.size(0) num_total = current_weight.size(0)
if num_total < 2 or num_prune < 1: if num_total < 2 or num_prune < 1:
return base_mask masks[name] = base_mask
continue
_tmp_mask = self.get_mask( _tmp_mask = self.get_mask(
base_mask, current_weight, num_prune, _w, _w_idx, channel_masks) base_mask, current_weight, num_prune, _w, _w_idx, channel_masks)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment