Unverified Commit 74250987 authored by CAI, RIZHAO's avatar CAI, RIZHAO Committed by GitHub
Browse files
parent d7920fd2
...@@ -113,7 +113,7 @@ class AGP_Pruner(Pruner): ...@@ -113,7 +113,7 @@ class AGP_Pruner(Pruner):
if k == 0 or target_sparsity >= 1 or target_sparsity <= 0: if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
return mask return mask
# if we want to generate new mask, we should update weigth first # if we want to generate new mask, we should update weigth first
w_abs = weight.abs() * mask w_abs = weight.abs() * mask['weight']
threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max() threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)} new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
self.mask_dict.update({op_name: new_mask}) self.mask_dict.update({op_name: new_mask})
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment