Unverified Commit cd98c48f authored by J-shang, committed by GitHub
Browse files

[Compression] Improve bank sparse execution efficiency (#5033)

parent d03c411c
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import itertools from functools import reduce
from typing import Any, Dict from typing import Any, Dict
import numpy as np import numpy as np
...@@ -69,28 +69,29 @@ class BankSparsityAllocator(SparsityAllocator): ...@@ -69,28 +69,29 @@ class BankSparsityAllocator(SparsityAllocator):
assert n_dim >= len(self.balance_gran), 'Dimension of balance_gran should be smaller than metric' assert n_dim >= len(self.balance_gran), 'Dimension of balance_gran should be smaller than metric'
# make up for balance_gran # make up for balance_gran
balance_gran = [1] * (n_dim - len(self.balance_gran)) + self.balance_gran balance_gran = [1] * (n_dim - len(self.balance_gran)) + self.balance_gran
balance_numel = reduce(lambda x, y: x * y, balance_gran)
reshape_size_split = []
reshape_size_balance = []
for i, j in zip(target_metric.shape, balance_gran): for i, j in zip(target_metric.shape, balance_gran):
assert i % j == 0, 'Length of {} {} is not aligned with balance granularity'.format(module_name, target_name) assert i % j == 0, 'Length of {} {} is not aligned with balance granularity'.format(module_name, target_name)
reshape_size_split.extend([i // j, j])
# FIXME: The following code need refactor, do it after scaling refactor is done. reshape_size_balance.append(i // j)
shrinked_mask = torch.ones(target_metric.shape).type_as(target_metric) reshape_size_balance.append(balance_numel)
loop_iters = [range(int(i / j)) for i, j in zip(target_metric.shape, balance_gran)]
for iter_params in itertools.product(*loop_iters): permute_dims_balance = [_ * 2 for _ in range(n_dim)] + [_ * 2 + 1 for _ in range(n_dim)]
index_str_list = [f"{iter_param * gran}:{(iter_param+1) * gran}"\ _target_metric = target_metric.reshape(reshape_size_split).permute(permute_dims_balance)
for iter_param, gran in zip(iter_params, balance_gran)] reshape_size_split_p = _target_metric.shape
index_str = ",".join(index_str_list) balance_metric = _target_metric.reshape(reshape_size_balance)
sub_metric_str = "target_metric[{}]".format(index_str)
sub_mask_str = "shrinked_mask[{}] = mask_bank".format(index_str) kept_num = balance_numel - int(sparsity_rate * balance_numel)
metric_bank: Tensor = eval(sub_metric_str) kept_indices = torch.topk(balance_metric, kept_num).indices
prune_num = int(sparsity_rate * metric_bank.numel()) shrinked_mask = torch.zeros_like(balance_metric).scatter(-1, kept_indices, 1.0).reshape(reshape_size_split_p)
# mask_bank will be used in exec(sub_mask_str)
if prune_num != 0: permute_dims_split = []
threshold = torch.topk(metric_bank.reshape(-1), prune_num, largest=False)[0].max() for i in range(n_dim):
mask_bank = torch.gt(metric_bank, threshold).type_as(metric_bank) # type: ignore permute_dims_split.extend([i, i + n_dim])
else: shrinked_mask = shrinked_mask.permute(permute_dims_split).reshape_as(target_metric)
mask_bank = torch.ones_like(metric_bank) # type: ignore
mask_bank = mask_bank # `type: ignore` is useless for unused-variable error, add this line to workaround
exec(sub_mask_str)
masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask) masks[module_name][target_name] = self._expand_mask(module_name, target_name, shrinked_mask)
return masks return masks
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment