"git@developer.sourcefind.cn:yangql/composable_kernel.git" did not exist on "c82b833d8e76094a3702046d81872132d5c4b15a"
Unverified commit 7811307c, authored by J-shang and committed by GitHub

[Test] fix compression ut (#5129)

* fix test often failed

* update

* fix lint
parent c88ac7b9
@@ -212,7 +212,6 @@ class DependencyAwareAllocator(SparsityAllocator):
         return fused_metrics
 
     def common_target_masks_generation(self, metrics: Dict[str, Dict[str, Tensor]]) -> Dict[str, Dict[str, Tensor]]:
-        # placeholder, here we need more discussion about dependence sparsity, Plan A or Plan B.
         masks = {}
         # generate public part for modules that have dependencies
         for module_names in self.channel_dependency:
@@ -228,7 +227,8 @@ class DependencyAwareAllocator(SparsityAllocator):
             group_nums = [self.group_dependency.get(module_name, 1) for module_name in sub_metrics.keys()]
             max_group_nums = int(np.lcm.reduce(group_nums))
-            pruned_numel_per_group = int(fused_metric.numel() // max_group_nums * min_sparsity_rate)
+            numel_per_group = fused_metric.numel() // max_group_nums
+            kept_numel_per_group = numel_per_group - int(numel_per_group * min_sparsity_rate)
             group_step = fused_metric.shape[0] // max_group_nums
             # get the public part of the mask of the module with dependencies
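The renamed variables are easier to follow with concrete numbers. Below is a minimal sketch of the arithmetic in this hunk, using made-up values (the real group_nums, metric size, and min_sparsity_rate depend on the model and config):

import numpy as np

# Hypothetical values, only to illustrate the arithmetic above.
group_nums = [2, 4]            # group constraints of the dependent modules
min_sparsity_rate = 0.75       # sparsity rate used for the shared part, as in the hunk above
fused_numel = 16               # pretend fused_metric.numel() == 16

max_group_nums = int(np.lcm.reduce(group_nums))      # lcm(2, 4) = 4
numel_per_group = fused_numel // max_group_nums      # 16 // 4 = 4
kept_numel_per_group = numel_per_group - int(numel_per_group * min_sparsity_rate)  # 4 - 3 = 1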
@@ -236,9 +236,15 @@ class DependencyAwareAllocator(SparsityAllocator):
             for gid in range(max_group_nums):
                 _start = gid * group_step
                 _end = (gid + 1) * group_step
-                if pruned_numel_per_group > 0:
-                    threshold = torch.topk(fused_metric[_start: _end].reshape(-1), pruned_numel_per_group, largest=False)[0].max()
-                    dependency_mask[_start: _end] = torch.gt(fused_metric[_start:_end], threshold).type_as(fused_metric)
+                if kept_numel_per_group > 0:
+                    flatten_partial_fused_metric = fused_metric[_start: _end].reshape(-1)
+                    kept_indices = torch.topk(flatten_partial_fused_metric, kept_numel_per_group).indices
+                    flatten_partial_mask = torch.zeros_like(flatten_partial_fused_metric).scatter(0, kept_indices, 1.0)
+                    dependency_mask[_start: _end] = flatten_partial_mask.reshape_as(dependency_mask[_start: _end])
+                else:
+                    # all zeros means this target will be whole masked, will break the model in most cases,
+                    # maybe replace this layer to identity layer in the future
+                    dependency_mask[_start: _end] = torch.zeros_like(dependency_mask[_start: _end])
             # change the metric value corresponding to the public mask part to the minimum value
             for module_name, targets_metric in sub_metrics.items():
@@ -262,8 +268,9 @@ class DependencyAwareAllocator(SparsityAllocator):
             sparsity_rate = wrapper.config['total_sparsity']
             prune_num = int(sparsity_rate * target_metric.numel())
             if prune_num != 0:
-                threshold = torch.topk(target_metric.reshape(-1), prune_num, largest=False)[0].max()
-                shrinked_mask = torch.gt(target_metric, threshold).type_as(target_metric)
+                flatten_metric = target_metric.reshape(-1)
+                kept_indices = torch.topk(flatten_metric, target_metric.numel() - prune_num).indices
+                shrinked_mask = torch.zeros_like(flatten_metric).scatter(0, kept_indices, 1.0).reshape_as(target_metric)
             else:
                 # target_metric should have the same size as shrinked_mask
                 shrinked_mask = torch.ones_like(target_metric)
......
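The core of the fix is visible in the two hunks above: the old code derived a threshold from the `prune_num` smallest metric values and kept everything strictly greater than it, so any values tied with the threshold were also dropped and the realized sparsity could exceed the requested one; the new code selects the kept indices directly with topk and scatters 1.0 into a zero mask, so the kept count is exact. A minimal standalone sketch of the difference, with made-up metric values:

import torch

# Made-up metric with tied values: four entries share the lowest score.
metric = torch.tensor([0.2, 0.2, 0.2, 0.2, 0.9, 0.8])
prune_num = 2                               # want exactly 2 entries masked out

# Old approach: threshold = max of the `prune_num` smallest values,
# then keep everything strictly greater than it.
threshold = torch.topk(metric, prune_num, largest=False)[0].max()
old_mask = torch.gt(metric, threshold).type_as(metric)
print(int(old_mask.sum()))                  # 2 kept -> 4 pruned, more than requested

# New approach: pick the indices to keep and scatter 1.0 into a zero mask,
# so exactly numel - prune_num entries survive.
kept_indices = torch.topk(metric, metric.numel() - prune_num).indices
new_mask = torch.zeros_like(metric).scatter(0, kept_indices, 1.0)
print(int(new_mask.sum()))                  # 4 kept -> exactly 2 pruned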
@@ -15,12 +15,12 @@ log_dir = Path(__file__).parent.parent / 'logs'
 def create_model(model_type: str):
-    torch_config_list = [{'op_types': ['Linear'], 'sparsity': 0.5},
-                         {'op_names': ['conv1', 'conv2', 'conv3'], 'sparsity': 0.5},
+    torch_config_list = [{'op_types': ['Linear'], 'sparsity': 0.75},
+                         {'op_names': ['conv1', 'conv2', 'conv3'], 'sparsity': 0.75},
                          {'op_names': ['fc2'], 'exclude': True}]
-    lightning_config_list = [{'op_types': ['Linear'], 'sparsity': 0.5},
-                             {'op_names': ['model.conv1', 'model.conv2', 'model.conv3'], 'sparsity': 0.5},
+    lightning_config_list = [{'op_types': ['Linear'], 'sparsity': 0.75},
+                             {'op_names': ['model.conv1', 'model.conv2', 'model.conv3'], 'sparsity': 0.75},
                              {'op_names': ['model.fc2'], 'exclude': True}]
     if model_type == 'lightning':
......
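For context, config lists of this shape are passed to an NNI pruner together with the model. A rough usage sketch, assuming L1NormPruner and a stand-in model (the diff does not show which pruner the test instantiates, and the real test builds SimpleTorchModel or its Lightning wrapper):

import torch
from nni.compression.pytorch.pruning import L1NormPruner

# Stand-in model; the real test uses SimpleTorchModel / a Lightning module.
model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 4))
config_list = [{'op_types': ['Linear'], 'sparsity': 0.75}]

pruner = L1NormPruner(model, config_list)
_, masks = pruner.compress()   # masks maps module names to their generated weight/bias masks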
@@ -23,11 +23,11 @@ class SimpleTorchModel(torch.nn.Module):
         super().__init__()
         self.conv1 = torch.nn.Conv2d(1, 16, 3)
         self.bn1 = torch.nn.BatchNorm2d(16)
-        self.conv2 = torch.nn.Conv2d(16, 8, 3, groups=4)
-        self.bn2 = torch.nn.BatchNorm2d(8)
-        self.conv3 = torch.nn.Conv2d(16, 8, 3)
-        self.bn3 = torch.nn.BatchNorm2d(8)
-        self.fc1 = torch.nn.Linear(8 * 24 * 24, 100)
+        self.conv2 = torch.nn.Conv2d(16, 32, 3, groups=4)
+        self.bn2 = torch.nn.BatchNorm2d(32)
+        self.conv3 = torch.nn.Conv2d(16, 32, 3)
+        self.bn3 = torch.nn.BatchNorm2d(32)
+        self.fc1 = torch.nn.Linear(32 * 24 * 24, 100)
         self.fc2 = torch.nn.Linear(100, 10)
 
     def forward(self, x: torch.Tensor):
......
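The widened conv channels keep the downstream shapes consistent with the new fc1 size. A quick standalone check, assuming a 28x28 single-channel input (the input size is not shown in this diff; it is the one that makes 32 * 24 * 24 line up):

import torch

# Assumed 28x28 single-channel input (not shown in the diff).
x = torch.randn(1, 1, 28, 28)
conv1 = torch.nn.Conv2d(1, 16, 3)
conv2 = torch.nn.Conv2d(16, 32, 3, groups=4)

out = conv2(conv1(x))
print(out.shape)                # torch.Size([1, 32, 24, 24])
print(out.flatten(1).shape[1])  # 18432 == 32 * 24 * 24, fc1's in_features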