Unverified Commit cd98c48f authored by J-shang, committed by GitHub

[Compression] Improve bank sparse execution efficiency (#5033)

parent d03c411c
@@ -3,7 +3,7 @@
 from __future__ import annotations
-import itertools
+from functools import reduce
 from typing import Any, Dict
 import numpy as np
@@ -69,28 +69,29 @@ class BankSparsityAllocator(SparsityAllocator):
 assert n_dim >= len(self.balance_gran), 'Dimension of balance_gran should be smaller than metric'
 # make up for balance_gran
 balance_gran = [1] * (n_dim - len(self.balance_gran)) + self.balance_gran
+balance_numel = reduce(lambda x, y: x * y, balance_gran)
+reshape_size_split = []
+reshape_size_balance = []
 for i, j in zip(target_metric.shape, balance_gran):
     assert i % j == 0, 'Length of {} {} is not aligned with balance granularity'.format(module_name, target_name)
-# FIXME: The following code need refactor, do it after scaling refactor is done.
-shrinked_mask = torch.ones(target_metric.shape).type_as(target_metric)
-loop_iters = [range(int(i / j)) for i, j in zip(target_metric.shape, balance_gran)]
-for iter_params in itertools.product(*loop_iters):
-    index_str_list = [f"{iter_param * gran}:{(iter_param+1) * gran}"\
-        for iter_param, gran in zip(iter_params, balance_gran)]
-    index_str = ",".join(index_str_list)
-    sub_metric_str = "target_metric[{}]".format(index_str)
-    sub_mask_str = "shrinked_mask[{}] = mask_bank".format(index_str)
-    metric_bank: Tensor = eval(sub_metric_str)
-    prune_num = int(sparsity_rate * metric_bank.numel())
-    # mask_bank will be used in exec(sub_mask_str)
-    if prune_num != 0:
-        threshold = torch.topk(metric_bank.reshape(-1), prune_num, largest=False)[0].max()
-        mask_bank = torch.gt(metric_bank, threshold).type_as(metric_bank) # type: ignore
-    else:
-        mask_bank = torch.ones_like(metric_bank) # type: ignore
-    mask_bank = mask_bank # `type: ignore` is useless for unused-variable error, add this line to workaround
-    exec(sub_mask_str)
+    reshape_size_split.extend([i // j, j])
+    reshape_size_balance.append(i // j)
+reshape_size_balance.append(balance_numel)
+permute_dims_balance = [_ * 2 for _ in range(n_dim)] + [_ * 2 + 1 for _ in range(n_dim)]
+_target_metric = target_metric.reshape(reshape_size_split).permute(permute_dims_balance)
+reshape_size_split_p = _target_metric.shape
+balance_metric = _target_metric.reshape(reshape_size_balance)
+kept_num = balance_numel - int(sparsity_rate * balance_numel)
+kept_indices = torch.topk(balance_metric, kept_num).indices
+shrinked_mask = torch.zeros_like(balance_metric).scatter(-1, kept_indices, 1.0).reshape(reshape_size_split_p)
+permute_dims_split = []
+for i in range(n_dim):
+    permute_dims_split.extend([i, i + n_dim])
+shrinked_mask = shrinked_mask.permute(permute_dims_split).reshape_as(target_metric)
 masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
 return masks
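The rewritten allocator builds every bank mask in one vectorized pass: each dimension is split into a (bank count, granularity) pair, the granularity axes are permuted to the back and flattened into a single trailing axis, and one topk/scatter call keeps the largest elements in every bank before the mask is permuted back to the original layout. The sketch below illustrates that idea on a toy tensor; the function name `bank_sparsity_mask`, its signature, and the demo shapes are illustrative assumptions, not part of the NNI API.

```python
# Minimal, self-contained sketch of vectorized bank sparsity, mirroring the added diff lines.
# `bank_sparsity_mask` and its arguments are illustrative names, not the NNI interface.
from functools import reduce

import torch
from torch import Tensor


def bank_sparsity_mask(metric: Tensor, sparsity: float, balance_gran: list) -> Tensor:
    n_dim = metric.dim()
    assert n_dim >= len(balance_gran)
    # pad the granularity with 1s so it matches the metric's dimensionality
    balance_gran = [1] * (n_dim - len(balance_gran)) + balance_gran
    balance_numel = reduce(lambda x, y: x * y, balance_gran)

    reshape_size_split, reshape_size_balance = [], []
    for i, j in zip(metric.shape, balance_gran):
        assert i % j == 0, 'metric shape is not aligned with balance granularity'
        reshape_size_split.extend([i // j, j])   # (bank count, granularity) per dim
        reshape_size_balance.append(i // j)
    reshape_size_balance.append(balance_numel)   # last axis holds one flattened bank

    # split every dim into (banks, granularity), then move all bank dims to the front
    permute_dims_balance = [d * 2 for d in range(n_dim)] + [d * 2 + 1 for d in range(n_dim)]
    _metric = metric.reshape(reshape_size_split).permute(permute_dims_balance)
    permuted_shape = _metric.shape

    # flatten each bank into the last axis and keep its top elements in one topk/scatter
    balance_metric = _metric.reshape(reshape_size_balance)
    kept_num = balance_numel - int(sparsity * balance_numel)
    kept_indices = torch.topk(balance_metric, kept_num).indices
    mask = torch.zeros_like(balance_metric).scatter(-1, kept_indices, 1.0).reshape(permuted_shape)

    # undo the permutation and restore the original shape
    permute_dims_split = []
    for d in range(n_dim):
        permute_dims_split.extend([d, d + n_dim])
    return mask.permute(permute_dims_split).reshape_as(metric)


if __name__ == '__main__':
    metric = torch.rand(4, 6)
    mask = bank_sparsity_mask(metric, sparsity=0.5, balance_gran=[2, 3])
    # every 2x3 bank keeps exactly 3 of its 6 elements
    print(mask.reshape(2, 2, 2, 3).sum(dim=(1, 3)))
```

Compared with the removed eval/exec loop over `itertools.product`, this formulation issues a fixed number of tensor operations no matter how many banks the weight contains, which is where the efficiency gain in this commit comes from.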