Unverified Commit a254f058 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

fix pre-masks inherit (#4428)

parent 6e643b00
...@@ -36,6 +36,8 @@ class NormalSparsityAllocator(SparsityAllocator): ...@@ -36,6 +36,8 @@ class NormalSparsityAllocator(SparsityAllocator):
threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max() threshold = torch.topk(metric.view(-1), prune_num, largest=False)[0].max()
mask = torch.gt(metric, threshold).type_as(metric) mask = torch.gt(metric, threshold).type_as(metric)
masks[name] = self._expand_mask(name, mask) masks[name] = self._expand_mask(name, mask)
if self.continuous_mask:
masks[name]['weight'] *= wrapper.weight_mask
return masks return masks
...@@ -55,6 +57,8 @@ class GlobalSparsityAllocator(SparsityAllocator): ...@@ -55,6 +57,8 @@ class GlobalSparsityAllocator(SparsityAllocator):
for name, metric in group_metric_dict.items(): for name, metric in group_metric_dict.items():
mask = torch.gt(metric, min(threshold, sub_thresholds[name])).type_as(metric) mask = torch.gt(metric, min(threshold, sub_thresholds[name])).type_as(metric)
masks[name] = self._expand_mask(name, mask) masks[name] = self._expand_mask(name, mask)
if self.continuous_mask:
masks[name]['weight'] *= self.pruner.get_modules_wrapper()[name].weight_mask
return masks return masks
def _calculate_threshold(self, group_metric_dict: Dict[str, Tensor]) -> Tuple[float, Dict[str, float]]: def _calculate_threshold(self, group_metric_dict: Dict[str, Tensor]) -> Tuple[float, Dict[str, float]]:
...@@ -158,7 +162,8 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator): ...@@ -158,7 +162,8 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
threshold = torch.topk(metric, pruned_num, largest=False)[0].max() threshold = torch.topk(metric, pruned_num, largest=False)[0].max()
mask = torch.gt(metric, threshold).type_as(metric) mask = torch.gt(metric, threshold).type_as(metric)
masks[name] = self._expand_mask(name, mask) masks[name] = self._expand_mask(name, mask)
if self.continuous_mask:
masks[name]['weight'] *= self.pruner.get_modules_wrapper()[name].weight_mask
return masks return masks
def _group_metric_calculate(self, group_metrics: Union[Dict[str, Tensor], List[Tensor]]) -> Tensor: def _group_metric_calculate(self, group_metrics: Union[Dict[str, Tensor], List[Tensor]]) -> Tensor:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment