Unverified commit 1abaf627, authored by J-shang, committed by GitHub

[Compression v2] bugfix & improvement (#4307)

parent 0845d79b
@@ -414,7 +414,7 @@ class SlimPruner(BasicPruner):
         def patched_criterion(input_tensor: Tensor, target: Tensor):
             sum_l1 = 0
             for _, wrapper in self.get_modules_wrapper().items():
-                sum_l1 += torch.norm(wrapper.module.weight.data, p=1)
+                sum_l1 += torch.norm(wrapper.module.weight, p=1)
             return criterion(input_tensor, target) + self._scale * sum_l1
         return patched_criterion
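Dropping `.data` here is the actual bugfix: `weight.data` is detached from autograd, so the L1 regularization term contributed nothing to the gradients and the scaling factors were never pushed toward zero. Using `weight` directly lets the penalty backpropagate. A minimal sketch of the difference, using a toy module rather than NNI's wrapper API:

```python
import torch
import torch.nn as nn

# Toy stand-in for a wrapped module; illustrative only.
layer = nn.Linear(4, 4)
criterion = nn.MSELoss()
scale = 1e-4

x, y = torch.randn(2, 4), torch.randn(2, 4)

# Buggy variant: .data detaches the weight, so the L1 term has no
# grad_fn and only the task loss reaches layer.weight.grad.
loss_detached = criterion(layer(x), y) + scale * torch.norm(layer.weight.data, p=1)

# Fixed variant: the L1 term stays in the graph, so its
# (sub)gradient, scale * sign(weight), is added during backward.
loss_attached = criterion(layer(x), y) + scale * torch.norm(layer.weight, p=1)

loss_attached.backward()  # gradients now include the sparsity penalty
```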
@@ -24,7 +24,7 @@ class WeightDataCollector(DataCollector):
     def collect(self) -> Dict[str, Tensor]:
         data = {}
         for _, wrapper in self.compressor.get_modules_wrapper().items():
-            data[wrapper.name] = wrapper.module.weight.data.clone().detach()
+            data[wrapper.name] = wrapper.module.weight.data
         return data
@@ -39,7 +39,7 @@ class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
         data = {}
         for _, wrapper in self.compressor.get_modules_wrapper().items():
-            data[wrapper.name] = wrapper.module.weight.data.clone().detach()
+            data[wrapper.name] = wrapper.module.weight.data
         return data
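Both collector hunks make the same change: `weight.data` is already detached from autograd, so the extra `clone().detach()` only cost a full copy of every collected weight. Returning `weight.data` directly shares storage with the live parameter. A small sketch of the storage behavior this relies on (toy names, not NNI code):

```python
import torch
import torch.nn as nn

layer = nn.Linear(8, 8)

copied = layer.weight.data.clone().detach()  # independent storage, extra memory
shared = layer.weight.data                   # same storage as the parameter

assert copied.data_ptr() != layer.weight.data_ptr()
assert shared.data_ptr() == layer.weight.data_ptr()

# The trade-off: later in-place updates to the weight are visible
# through `shared`, so the metric must be computed before training
# mutates the parameter again.
with torch.no_grad():
    layer.weight.add_(1.0)
print(torch.equal(shared, layer.weight.data))  # True: it tracked the update
```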
@@ -24,6 +24,8 @@ class NormalSparsityAllocator(SparsityAllocator):
             sparsity_rate = wrapper.config['total_sparsity']
             assert name in metrics, 'Metric of %s is not calculated.'
+            # We assume the metric value are all positive right now.
             metric = metrics[name]
             if self.continuous_mask:
                 metric *= self._compress_mask(wrapper.weight_mask)
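The new comment documents the invariant the multiplication below depends on: with all-positive metrics, multiplying by the previous 0/1 mask pins already-pruned positions at exactly zero, the smallest possible value, so a subsequent smallest-k selection keeps them pruned across iterations. If metrics could be negative, a masked-out zero could outrank a live weight. A minimal sketch of the mechanism (toy numbers, not NNI code):

```python
import torch

metric = torch.tensor([0.9, 0.1, 0.5, 0.3])    # assumed all positive
old_mask = torch.tensor([1.0, 0.0, 1.0, 1.0])  # position 1 pruned earlier

masked_metric = metric * old_mask              # pruned entry pinned at 0.0
pruned_num = 2
# smallest-k selection: previously pruned entries always rank lowest
threshold = torch.topk(masked_metric, pruned_num, largest=False)[0].max()
new_mask = torch.gt(masked_metric, threshold).float()
print(new_mask)  # tensor([1., 0., 1., 0.]): old pruning kept, 0.3 pruned next
```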
@@ -66,8 +68,11 @@ class GlobalSparsityAllocator(SparsityAllocator):
         for name, metric in group_metric_dict.items():
             wrapper = self.pruner.get_modules_wrapper()[name]
+            # We assume the metric value are all positive right now.
             if self.continuous_mask:
                 metric = metric * self._compress_mask(wrapper.weight_mask)
             layer_weight_num = wrapper.module.weight.data.numel()
             total_weight_num += layer_weight_num
             expend_times = int(layer_weight_num / metric.numel())
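`expend_times` (sic, the variable name in the source) records how many weights each metric entry stands for: a channel-level metric has far fewer entries than the layer has weights, so when metrics are pooled across layers to pick one global threshold, each entry must be weighted by the number of weights it covers to keep layers of different granularity comparable. A rough sketch of the idea (simplified, not the actual allocator):

```python
import torch

# Two layers pooled globally: one scored per-channel, one per-weight.
layer_weight_num = {'conv': 64 * 9, 'fc': 32}             # weights per layer
metrics = {'conv': torch.rand(64), 'fc': torch.rand(32)}  # per-entry scores

pooled = []
for name, metric in metrics.items():
    # each metric entry represents this many individual weights
    expend_times = int(layer_weight_num[name] / metric.numel())  # 9 and 1
    pooled.append(metric.repeat_interleave(expend_times))
pooled = torch.cat(pooled)

# A single global threshold now weighs both layers fairly.
total_weight_num = sum(layer_weight_num.values())
pruned_num = int(0.5 * total_weight_num)
threshold = torch.topk(pooled, pruned_num, largest=False)[0].max()
```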
@@ -147,7 +152,8 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
             group_mask = torch.cat(group_mask, dim=0)
             for name, metric in group_metric_dict.items():
-                metric = (metric - metric.min()) * group_mask
+                # We assume the metric value are all positive right now.
+                metric = metric * group_mask
                 pruned_num = int(sparsities[name] * len(metric))
                 threshold = torch.topk(metric, pruned_num, largest=False)[0].max()
                 mask = torch.gt(metric, threshold).type_as(metric)
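Dropping the `metric - metric.min()` shift is the other fix: under the all-positive assumption the shift is unnecessary, and it was actively harmful because it maps the smallest live channel to exactly 0, making it indistinguishable from channels the group mask already zeroed out, so `torch.gt(metric, threshold)` could prune it as a tie-along. A small sketch of that failure mode (toy numbers, not NNI code):

```python
import torch

metric = torch.tensor([0.4, 0.2, 0.9, 0.6])      # all-positive channel scores
group_mask = torch.tensor([1.0, 1.0, 1.0, 0.0])  # channel 3 masked by the group
pruned_num = 1                                   # exactly one entry should fall below

# Old behavior: the min-shift collapses the smallest live channel to 0,
# tying it with the already-masked channel, so both fall below threshold.
shifted = (metric - metric.min()) * group_mask   # [0.2, 0.0, 0.7, 0.0]
t_old = torch.topk(shifted, pruned_num, largest=False)[0].max()
print(torch.gt(shifted, t_old).float())          # [1., 0., 1., 0.]: an extra channel lost

# New behavior: positivity keeps live channels strictly above masked ones.
masked = metric * group_mask                     # [0.4, 0.2, 0.9, 0.0]
t_new = torch.topk(masked, pruned_num, largest=False)[0].max()
print(torch.gt(masked, t_new).float())           # [1., 1., 1., 0.]: only the masked channel
```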