"docs/en_US/Tutorial/InstallationLinux.md" did not exist on "9daf7c9506aa39c87682547260fc902a7e6e9476"
Unverified Commit d8127e02 authored by lin bin's avatar lin bin Committed by GitHub
Browse files

[Model Compression] Add global sort for taylor pruner (#3896)

parent c80ed3e9
...@@ -334,6 +334,8 @@ TaylorFOWeightFilter Pruner is a pruner which prunes convolutional layers based ...@@ -334,6 +334,8 @@ TaylorFOWeightFilter Pruner is a pruner which prunes convolutional layers based
We also provide a dependency-aware mode for this pruner to get better speedup from the pruning. Please reference `dependency-aware <./DependencyAware.rst>`__ for more details. We also provide a dependency-aware mode for this pruner to get better speedup from the pruning. Please reference `dependency-aware <./DependencyAware.rst>`__ for more details.
What's more, we provide a global-sort mode for this pruner, which is aligned with the paper's implementation. Please set the parameter ``global_sort`` to ``True`` when instantiating ``TaylorFOWeightFilterPruner``.
Usage Usage
^^^^^ ^^^^^
......
...@@ -218,6 +218,10 @@ def main(args): ...@@ -218,6 +218,10 @@ def main(args):
}] }]
else: else:
if args.global_sort:
print('Enable the global_sort mode')
# only taylor pruner supports global sort mode currently
kw_args['global_sort'] = True
if args.dependency_aware: if args.dependency_aware:
dummy_input = get_dummy_input(args, device) dummy_input = get_dummy_input(args, device)
print('Enable the dependency_aware mode') print('Enable the dependency_aware mode')
...@@ -340,6 +344,8 @@ if __name__ == '__main__': ...@@ -340,6 +344,8 @@ if __name__ == '__main__':
help='target overall target sparsity') help='target overall target sparsity')
parser.add_argument('--dependency-aware', action='store_true', default=False, parser.add_argument('--dependency-aware', action='store_true', default=False,
help='toggle dependency aware mode') help='toggle dependency aware mode')
parser.add_argument('--global-sort', action='store_true', default=False,
help='toggle global sort mode')
parser.add_argument('--pruner', type=str, default='l1filter', parser.add_argument('--pruner', type=str, default='l1filter',
choices=['level', 'l1filter', 'l2filter', 'slim', 'agp', choices=['level', 'l1filter', 'l2filter', 'slim', 'agp',
'fpgm', 'mean_activation', 'apoz', 'taylorfo'], 'fpgm', 'mean_activation', 'apoz', 'taylorfo'],
...@@ -365,4 +371,4 @@ if __name__ == '__main__': ...@@ -365,4 +371,4 @@ if __name__ == '__main__':
args.pruner = params['pruner'] args.pruner = params['pruner']
args.model = params['model'] args.model = params['model']
main(args) main(args)
\ No newline at end of file
...@@ -491,14 +491,20 @@ class TaylorFOWeightFilterPruner(IterativePruner): ...@@ -491,14 +491,20 @@ class TaylorFOWeightFilterPruner(IterativePruner):
dummy_input : torch.Tensor dummy_input : torch.Tensor
The dummy input to analyze the topology constraints. Note that, the dummy_input The dummy input to analyze the topology constraints. Note that, the dummy_input
should on the same device with the model. should on the same device with the model.
global_sort: bool
Only support TaylorFOWeightFilterPruner currently.
Whether to prune the model in a global-sort way. If it is `True`, this pruner will prune
the model according to global contribution information: channel contributions are
sorted across all layers, and whether a specific channel is pruned depends on that global ranking.
""" """
def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1,
             dependency_aware=False, dummy_input=None, global_sort=False):
    """Construct a TaylorFO weight-filter pruner.

    Runs a single pruning iteration (``num_iterations=1``) driven by the
    taylorfo masking algorithm; ``sparsifying_training_batches`` controls how
    many training batches are used to accumulate contribution statistics.
    ``global_sort`` is forwarded to the masker so channel contributions can be
    ranked across all layers instead of per layer.
    """
    super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='taylorfo', trainer=trainer,
                     criterion=criterion, statistics_batch_num=sparsifying_training_batches, num_iterations=1,
                     epochs_per_iteration=1, dependency_aware=dependency_aware,
                     dummy_input=dummy_input)
    # the masker (created by the base class) does the actual mask computation,
    # so the global-sort switch lives on it
    self.masker.global_sort = global_sort
def _supported_dependency_aware(self): def _supported_dependency_aware(self):
return True return True
......
...@@ -33,11 +33,12 @@ class StructuredWeightMasker(WeightMasker): ...@@ -33,11 +33,12 @@ class StructuredWeightMasker(WeightMasker):
""" """
def __init__(self, model, pruner, preserve_round=1, dependency_aware=False, global_sort=False):
    """Base initializer for structured weight maskers.

    Stores the model/pruner handles and the masking-mode switches:
    ``dependency_aware`` groups layers with topology constraints, while
    ``global_sort`` (supported by TaylorFO only) ranks channel contributions
    across all layers instead of per layer.
    """
    self.model = model
    self.pruner = pruner
    self.preserve_round = preserve_round
    self.dependency_aware = dependency_aware
    # when True, calc_mask routes through _global_calc_mask
    self.global_sort = global_sort
def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs): def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs):
""" """
...@@ -60,7 +61,11 @@ class StructuredWeightMasker(WeightMasker): ...@@ -60,7 +61,11 @@ class StructuredWeightMasker(WeightMasker):
depen_kwargs: dict depen_kwargs: dict
The kw_args for the dependency-aware mode. The kw_args for the dependency-aware mode.
""" """
if not self.dependency_aware: if self.global_sort:
# if the global_sort switch is on, calculate the mask based
# on global model information
return self._global_calc_mask(sparsity, wrapper, wrapper_idx)
elif not self.dependency_aware:
# calculate the mask in the normal way, each layer calculate its # calculate the mask in the normal way, each layer calculate its
# own mask separately # own mask separately
return self._normal_calc_mask(sparsity, wrapper, wrapper_idx) return self._normal_calc_mask(sparsity, wrapper, wrapper_idx)
...@@ -127,6 +132,12 @@ class StructuredWeightMasker(WeightMasker): ...@@ -127,6 +132,12 @@ class StructuredWeightMasker(WeightMasker):
# weight*mask_weight: apply base mask for iterative pruning # weight*mask_weight: apply base mask for iterative pruning
return mask, weight * mask_weight, num_prune return mask, weight * mask_weight, num_prune
def _global_calc_mask(self, sparsity, wrapper, wrapper_idx=None):
    """Calculate this layer's mask from the globally sorted contributions.

    The number of filters to prune is derived from the global ranking rather
    than the per-layer ``sparsity``; the mask itself is then produced by the
    regular ``get_mask`` routine.
    """
    prune_count = self._get_global_num_prune(wrapper, wrapper_idx)
    base_mask, current_weight, _ = self._get_current_state(sparsity, wrapper, wrapper_idx)
    return self.get_mask(base_mask, current_weight, prune_count, wrapper, wrapper_idx)
def _normal_calc_mask(self, sparsity, wrapper, wrapper_idx=None): def _normal_calc_mask(self, sparsity, wrapper, wrapper_idx=None):
""" """
Calculate the mask of given layer. Calculate the mask of given layer.
...@@ -477,6 +488,31 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker): ...@@ -477,6 +488,31 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
self.pruner.iterations = 0 self.pruner.iterations = 0
self.pruner.set_wrappers_attribute("contribution", None) self.pruner.set_wrappers_attribute("contribution", None)
self.pruner.patch_optimizer(self.calc_contributions) self.pruner.patch_optimizer(self.calc_contributions)
self.global_threshold = None
def _get_global_threshold(self):
    """Compute the global contribution threshold shared by all layers.

    Each channel's contribution is replicated once per weight element of that
    channel, so larger layers carry proportionally more entries in the global
    ranking. The threshold is the largest value among the bottom-k entries,
    where k is derived from the sparsity of the first config entry
    (global-sort mode assumes a single overall sparsity).
    """
    channel_contribution_list = []
    for wrapper_idx, wrapper in enumerate(self.pruner.get_modules_wrapper()):
        channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
        wrapper_size = wrapper.module.weight.size().numel()
        channel_size = wrapper.module.weight.size(0)
        # repeat each channel's contribution once per weight element in the
        # channel; floor division is exact here (wrapper_size is a multiple
        # of channel_size) and avoids float round-off
        contribution_expand = channel_contribution.expand(
            wrapper_size // channel_size, channel_size).reshape(-1)
        channel_contribution_list.append(contribution_expand)
    all_channel_contributions = torch.cat(channel_contribution_list)
    k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity'])
    if k < 1:
        # sparsity rounds down to zero elements: prune nothing instead of
        # crashing on .max() of an empty topk result
        self.global_threshold = float('-inf')
        return
    self.global_threshold = torch.topk(
        all_channel_contributions.view(-1), k, largest=False)[0].max()
def _get_global_num_prune(self, wrapper, wrapper_idx):
    """Return how many filters of this layer fall below the global threshold.

    Lazily computes the global threshold on first use. At least one filter is
    always kept, so a layer is never pruned away entirely.
    """
    if self.global_threshold is None:
        self._get_global_threshold()
    total_filters = wrapper.module.weight.data.size(0)
    contribution = self.get_channel_sum(wrapper, wrapper_idx)
    below_threshold = int((contribution < self.global_threshold).sum().item())
    # cap at total_filters - 1: never prune every filter of a layer
    return min(below_threshold, total_filters - 1)
def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None): def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
channel_contribution = self.get_channel_sum(wrapper, wrapper_idx) channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
......
...@@ -219,6 +219,50 @@ class CompressorTestCase(TestCase): ...@@ -219,6 +219,50 @@ class CompressorTestCase(TestCase):
assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.])) assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ])) assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ]))
def test_torch_taylorFOweight_pruner_global_sort(self):
    """
    After enabling global_sort, taylorFOweight pruner will calculate contributions and rank topk from all
    of the conv operators. Then it will prune low contribution filters depends on the global information.
    So if sparsity of conv operator is 0.4, the expected masks should mask out filter 0 and filter 1 together,
    this can be verified through:
    `all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.]))`
    `all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))`
    """
    config_list = [{'sparsity': 0.4, 'op_types': ['Conv2d']}]

    # hand-crafted weights/gradients so taylor contributions are known exactly
    conv1_weight = np.array([np.zeros((1, 5, 5)), np.ones((1, 5, 5)), np.ones((1, 5, 5)) * 2,
                             np.ones((1, 5, 5)) * 3, np.ones((1, 5, 5)) * 4])
    conv2_weight = np.array([[[[i + 1] * 5] * 5] * 5 for i in range(10)[::-1]])
    conv1_grad = np.array([np.ones((1, 5, 5)) * -1, np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1,
                           np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1])
    conv2_grad = np.array([[[[(-1)**i] * 5] * 5] * 5 for i in range(10)])

    model = TorchModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsifying_training_batches=1, global_sort=True)

    dummy_batch = torch.rand((1, 1, 28, 28), requires_grad=True)
    model.conv1.module.weight.data = torch.tensor(conv1_weight).float()
    model.conv2.module.weight.data = torch.tensor(conv2_weight).float()

    # one forward/backward pass, then overwrite the gradients with the
    # crafted values before the patched optimizer.step() records contributions
    output = model(dummy_batch)
    output.backward(torch.ones_like(output))
    model.conv1.module.weight.grad.data = torch.tensor(conv1_grad).float()
    model.conv2.module.weight.grad.data = torch.tensor(conv2_grad).float()
    optimizer.step()

    mask1 = pruner.calc_mask(model.conv1)
    mask2 = pruner.calc_mask(model.conv2)
    print(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy())
    print(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy())
    assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.]))
    assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))
def test_torch_QAT_quantizer(self): def test_torch_QAT_quantizer(self):
model = TorchModel() model = TorchModel()
config_list = [{ config_list = [{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment