"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "c993f767a8a5ce041921c1825ac407d9be847bb9"
Unverified Commit e428db54 authored by Panacea's avatar Panacea Committed by GitHub
Browse files

Fix [v2 sparsity_allocator] the way tensor expands (#4265)

parent 52c2d4d3
...@@ -80,7 +80,7 @@ class GlobalSparsityAllocator(SparsityAllocator): ...@@ -80,7 +80,7 @@ class GlobalSparsityAllocator(SparsityAllocator):
stay_metric = torch.topk(metric.view(-1), stay_metric_num, largest=False)[0] stay_metric = torch.topk(metric.view(-1), stay_metric_num, largest=False)[0]
sub_thresholds[name] = stay_metric.max() sub_thresholds[name] = stay_metric.max()
if expend_times > 1: if expend_times > 1:
stay_metric = stay_metric.expand(stay_metric_num, int(layer_weight_num / metric.numel())).view(-1) stay_metric = stay_metric.expand(int(layer_weight_num / metric.numel()), stay_metric_num).contiguous().view(-1)
metric_list.append(stay_metric) metric_list.append(stay_metric)
total_prune_num = int(total_sparsity * total_weight_num) total_prune_num = int(total_sparsity * total_weight_num)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment