"git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "cb961c87ca18e7ea9fc0735690325fc887a79f04"
Unverified commit d8127e02 authored by lin bin, committed by GitHub

[Model Compression] Add global sort for taylor pruner (#3896)

parent c80ed3e9
......@@ -334,6 +334,8 @@ TaylorFOWeightFilter Pruner is a pruner which prunes convolutional layers based
We also provide a dependency-aware mode for this pruner to get a better speedup from the pruning. Please refer to `dependency-aware <./DependencyAware.rst>`__ for more details.
In addition, we provide a global-sort mode for this pruner, which is aligned with the paper's implementation. Please set the parameter 'global_sort' to True when instantiating TaylorFOWeightFilterPruner (see the sketch under Usage below).
Usage
^^^^^
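A minimal usage sketch of the global-sort mode (the import path assumes a recent NNI release; ``model``, ``optimizer``, ``trainer`` and ``criterion`` are user-provided objects):

.. code-block:: python

   from nni.algorithms.compression.pytorch.pruning import TaylorFOWeightFilterPruner

   config_list = [{'sparsity': 0.5, 'op_types': ['Conv2d']}]
   pruner = TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer, criterion,
                                       sparsifying_training_batches=1, global_sort=True)
   pruner.compress()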
......
......@@ -218,6 +218,10 @@ def main(args):
}]
else:
if args.global_sort:
print('Enable the global_sort mode')
# only taylor pruner supports global sort mode currently
kw_args['global_sort'] = True
if args.dependency_aware:
dummy_input = get_dummy_input(args, device)
print('Enable the dependency_aware mode')
......@@ -340,6 +344,8 @@ if __name__ == '__main__':
help='overall target sparsity')
parser.add_argument('--dependency-aware', action='store_true', default=False,
help='toggle dependency aware mode')
parser.add_argument('--global-sort', action='store_true', default=False,
help='toggle global sort mode')
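# Example invocation (the script name is assumed for illustration, not part of this diff):
#   python basic_pruners_torch.py --pruner taylorfo --global-sort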
parser.add_argument('--pruner', type=str, default='l1filter',
choices=['level', 'l1filter', 'l2filter', 'slim', 'agp',
'fpgm', 'mean_activation', 'apoz', 'taylorfo'],
......@@ -365,4 +371,4 @@ if __name__ == '__main__':
args.pruner = params['pruner']
args.model = params['model']
main(args)
main(args)
\ No newline at end of file
......@@ -491,14 +491,20 @@ class TaylorFOWeightFilterPruner(IterativePruner):
dummy_input : torch.Tensor
The dummy input to analyze the topology constraints. Note that, the dummy_input
should be on the same device as the model.
global_sort: bool
Currently only supported by TaylorFOWeightFilterPruner.
Whether to prune the model in a global-sort way. If it is `True`, this pruner will prune
the model according to global contribution information, which means channel contributions
will be sorted globally and whether a specific channel is pruned depends on this global ranking.
"""
def __init__(self, model, config_list, optimizer, trainer, criterion, sparsifying_training_batches=1,
dependency_aware=False, dummy_input=None):
dependency_aware=False, dummy_input=None, global_sort=False):
super().__init__(model, config_list, optimizer=optimizer, pruning_algorithm='taylorfo', trainer=trainer,
criterion=criterion, statistics_batch_num=sparsifying_training_batches, num_iterations=1,
epochs_per_iteration=1, dependency_aware=dependency_aware,
dummy_input=dummy_input)
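# forward the global-sort switch to the masker; when enabled, the masker ranks
# channel contributions across all pruned layers instead of within each layer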
self.masker.global_sort = global_sort
def _supported_dependency_aware(self):
return True
......
......@@ -33,11 +33,12 @@ class StructuredWeightMasker(WeightMasker):
"""
def __init__(self, model, pruner, preserve_round=1, dependency_aware=False):
def __init__(self, model, pruner, preserve_round=1, dependency_aware=False, global_sort=False):
self.model = model
self.pruner = pruner
self.preserve_round = preserve_round
self.dependency_aware = dependency_aware
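# when global_sort is True, masks are computed against a contribution threshold
# shared across all pruned layers instead of a per-layer sparsity target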
self.global_sort = global_sort
def calc_mask(self, sparsity, wrapper, wrapper_idx=None, **depen_kwargs):
"""
......@@ -60,7 +61,11 @@ class StructuredWeightMasker(WeightMasker):
depen_kwargs: dict
The kw_args for the dependency-aware mode.
"""
if not self.dependency_aware:
if self.global_sort:
# if the global_sort switch is on, calculate the mask based
# on global model information
return self._global_calc_mask(sparsity, wrapper, wrapper_idx)
elif not self.dependency_aware:
# calculate the mask in the normal way; each layer calculates its
# own mask separately
return self._normal_calc_mask(sparsity, wrapper, wrapper_idx)
......@@ -127,6 +132,12 @@ class StructuredWeightMasker(WeightMasker):
# weight*mask_weight: apply base mask for iterative pruning
return mask, weight * mask_weight, num_prune
def _global_calc_mask(self, sparsity, wrapper, wrapper_idx=None):
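# same flow as _normal_calc_mask, except that the number of filters to prune is
# derived from the global contribution threshold rather than this layer's own sparsity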
num_prune = self._get_global_num_prune(wrapper, wrapper_idx)
mask, weight, _ = self._get_current_state(
sparsity, wrapper, wrapper_idx)
return self.get_mask(mask, weight, num_prune, wrapper, wrapper_idx)
def _normal_calc_mask(self, sparsity, wrapper, wrapper_idx=None):
"""
Calculate the mask of given layer.
......@@ -477,6 +488,31 @@ class TaylorFOWeightFilterPrunerMasker(StructuredWeightMasker):
self.pruner.iterations = 0
self.pruner.set_wrappers_attribute("contribution", None)
self.pruner.patch_optimizer(self.calc_contributions)
self.global_threshold = None
def _get_global_threshold(self):
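# concatenate the contributions of every pruned layer (each filter's contribution is
# repeated once per weight element of that filter) and take the largest value among the
# k smallest entries as a single global threshold, where k = total elements * sparsity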
channel_contribution_list = []
for wrapper_idx, wrapper in enumerate(self.pruner.get_modules_wrapper()):
channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
wrapper_size = wrapper.module.weight.size().numel()
channel_size = wrapper.module.weight.size(0)
contribution_expand = channel_contribution.expand(int(wrapper_size / channel_size), channel_size).reshape(-1)
channel_contribution_list.append(contribution_expand)
all_channel_contributions = torch.cat(channel_contribution_list)
k = int(all_channel_contributions.shape[0] * self.pruner.config_list[0]['sparsity'])
self.global_threshold = torch.topk(
all_channel_contributions.view(-1), k, largest=False)[0].max()
def _get_global_num_prune(self, wrapper, wrapper_idx):
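# the number of filters to prune in this layer is the number of channels whose
# contribution falls below the global threshold (always keeping at least one filter)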
if self.global_threshold is None:
self._get_global_threshold()
weight = wrapper.module.weight.data
filters = weight.size(0)
channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
num_prune = channel_contribution[channel_contribution < self.global_threshold].size()[0]
if num_prune == filters:
num_prune -= 1
return num_prune
def get_mask(self, base_mask, weight, num_prune, wrapper, wrapper_idx, channel_masks=None):
channel_contribution = self.get_channel_sum(wrapper, wrapper_idx)
......
......@@ -219,6 +219,50 @@ class CompressorTestCase(TestCase):
assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 25., 25., 25., 25.]))
assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 0., 0., 0., 0., 0., 0., ]))
def test_torch_taylorFOweight_pruner_global_sort(self):
"""
After enabling global_sort, the TaylorFOWeight pruner calculates contributions and ranks the top-k across all
of the conv operators, then prunes the low-contribution filters based on this global information.
So with a sparsity of 0.4 for the conv operators, the expected masks should mask out filters from both
conv layers jointly, which can be verified through:
`all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.]))`
`all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))`
"""
w1 = np.array([np.zeros((1, 5, 5)), np.ones((1, 5, 5)), np.ones((1, 5, 5)) * 2,
np.ones((1, 5, 5)) * 3, np.ones((1, 5, 5)) * 4])
w2 = np.array([[[[i + 1] * 5] * 5] * 5 for i in range(10)[::-1]])
grad1 = np.array([np.ones((1, 5, 5)) * -1, np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1,
np.ones((1, 5, 5)) * 1, np.ones((1, 5, 5)) * -1])
grad2 = np.array([[[[(-1)**i] * 5] * 5] * 5 for i in range(10)])
config_list = [{'sparsity': 0.4, 'op_types': ['Conv2d']}]
model = TorchModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
pruner = torch_pruner.TaylorFOWeightFilterPruner(model, config_list, optimizer, trainer=None, criterion=None, sparsifying_training_batches=1, global_sort=True)
x = torch.rand((1, 1, 28, 28), requires_grad=True)
model.conv1.module.weight.data = torch.tensor(w1).float()
model.conv2.module.weight.data = torch.tensor(w2).float()
y = model(x)
y.backward(torch.ones_like(y))
model.conv1.module.weight.grad.data = torch.tensor(grad1).float()
model.conv2.module.weight.grad.data = torch.tensor(grad2).float()
optimizer.step()
mask1 = pruner.calc_mask(model.conv1)
mask2 = pruner.calc_mask(model.conv2)
print(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy())
print(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy())
assert all(torch.sum(mask1['weight_mask'], (1, 2, 3)).numpy() == np.array([0., 0., 0, 0., 25.]))
assert all(torch.sum(mask2['weight_mask'], (1, 2, 3)).numpy() == np.array([125., 125., 125., 125., 125., 125., 125., 0., 0., 0.]))
def test_torch_QAT_quantizer(self):
model = TorchModel()
config_list = [{
......