Unverified Commit 1abaf627 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

[Compression v2] bugfix & improvement (#4307)

parent 0845d79b
......@@ -414,7 +414,7 @@ class SlimPruner(BasicPruner):
def patched_criterion(input_tensor: Tensor, target: Tensor):
sum_l1 = 0
for _, wrapper in self.get_modules_wrapper().items():
sum_l1 += torch.norm(wrapper.module.weight.data, p=1)
sum_l1 += torch.norm(wrapper.module.weight, p=1)
return criterion(input_tensor, target) + self._scale * sum_l1
return patched_criterion
......
......@@ -24,7 +24,7 @@ class WeightDataCollector(DataCollector):
def collect(self) -> Dict[str, Tensor]:
data = {}
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.module.weight.data.clone().detach()
data[wrapper.name] = wrapper.module.weight.data
return data
......@@ -39,7 +39,7 @@ class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
data = {}
for _, wrapper in self.compressor.get_modules_wrapper().items():
data[wrapper.name] = wrapper.module.weight.data.clone().detach()
data[wrapper.name] = wrapper.module.weight.data
return data
......
......@@ -24,6 +24,8 @@ class NormalSparsityAllocator(SparsityAllocator):
sparsity_rate = wrapper.config['total_sparsity']
assert name in metrics, 'Metric of %s is not calculated.'
# We assume the metric value are all positive right now.
metric = metrics[name]
if self.continuous_mask:
metric *= self._compress_mask(wrapper.weight_mask)
......@@ -66,8 +68,11 @@ class GlobalSparsityAllocator(SparsityAllocator):
for name, metric in group_metric_dict.items():
wrapper = self.pruner.get_modules_wrapper()[name]
# We assume the metric value are all positive right now.
if self.continuous_mask:
metric = metric * self._compress_mask(wrapper.weight_mask)
layer_weight_num = wrapper.module.weight.data.numel()
total_weight_num += layer_weight_num
expend_times = int(layer_weight_num / metric.numel())
......@@ -147,7 +152,8 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
group_mask = torch.cat(group_mask, dim=0)
for name, metric in group_metric_dict.items():
metric = (metric - metric.min()) * group_mask
# We assume the metric value are all positive right now.
metric = metric * group_mask
pruned_num = int(sparsities[name] * len(metric))
threshold = torch.topk(metric, pruned_num, largest=False)[0].max()
mask = torch.gt(metric, threshold).type_as(metric)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment