Commit 7c4e81b5 authored by Tang Lang's avatar Tang Lang Committed by QuanluZhang
Browse files

fix pruner export (#1727)

parent b37fbca8
...@@ -206,6 +206,8 @@ class Pruner(Compressor): ...@@ -206,6 +206,8 @@ class Pruner(Compressor):
""" """
assert model_path is not None, 'model_path must be specified' assert model_path is not None, 'model_path must be specified'
for name, m in self.bound_model.named_modules(): for name, m in self.bound_model.named_modules():
if name == "":
continue
mask = self.mask_dict.get(name) mask = self.mask_dict.get(name)
if mask is not None: if mask is not None:
mask_sum = mask.sum().item() mask_sum = mask.sum().item()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment