"docs/source/vscode:/vscode.git/clone" did not exist on "c4a4750cb31fdc6641d86cc8165cc9fccedf0a91"
Unverified Commit cbac2c5c authored by J-shang, committed by GitHub

[Compression] fix typehints (#4800)

parent d49864ce
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
-import collections
 import logging
-from typing import List, Dict, Optional, Tuple, Any
+from typing import Any, List, Dict, Optional, Tuple
 import torch
 from torch.nn import Module
@@ -29,7 +28,33 @@ def _setattr(model: Module, name: str, module: Module):
         name_list = name.split(".")
         setattr(parent_module, name_list[-1], module)
     else:
-        raise '{} not exist.'.format(name)
+        raise Exception('{} not exist.'.format(name))
+class ModuleWrapper(Module):
+    """
+    Wrap a module to enable data parallel, forward method customization and buffer registration.
+
+    Parameters
+    ----------
+    module
+        The module user wants to compress.
+    config
+        The configurations that users specify for compression.
+    module_name
+        The name of the module to compress, wrapper module shares same name.
+    """
+
+    def __init__(self, module: Module, module_name: str, config: Dict):
+        super().__init__()
+        # origin layer information
+        self.module = module
+        self.name = module_name
+        # config information
+        self.config = config
+
+    def forward(self, *inputs):
+        raise NotImplementedError

 class Compressor:
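The new ModuleWrapper above is an abstract base: subclasses keep a reference to the original module and override forward to add compression behavior. A minimal sketch of the intended pattern (the MaskedWrapper name and its mask logic are illustrative, not part of this commit):

import torch
from torch.nn import Module, Linear

class MaskedWrapper(Module):
    """Illustrative ModuleWrapper subclass: masks the wrapped module's weight."""
    def __init__(self, module: Module, module_name: str, config: dict):
        super().__init__()
        self.module = module
        self.name = module_name
        self.config = config
        # buffers follow the wrapper across devices and into state_dict
        self.register_buffer('weight_mask', torch.ones_like(module.weight))

    def forward(self, *inputs):
        # zero out pruned weights before delegating to the wrapped module
        self.module.weight.data.mul_(self.weight_mask)
        return self.module(*inputs)

wrapped = MaskedWrapper(Linear(4, 2), 'fc', {'sparsity': 0.5})
print(wrapped(torch.randn(1, 4)).shape)  # torch.Size([1, 2])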
@@ -46,7 +71,7 @@ class Compressor:
     def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]]):
         self.is_wrapped = False
-        if model is not None:
+        if model is not None and config_list is not None:
             self.reset(model=model, config_list=config_list)
         else:
             _logger.warning('This compressor is not set model and config_list, waiting for reset() or pass this to scheduler.')
@@ -63,6 +88,7 @@ class Compressor:
             The config list used by compressor, usually specifies the 'op_types' or 'op_names' that want to compress.
         """
         assert isinstance(model, Module), 'Only support compressing pytorch Module, but the type of model is {}.'.format(type(model))
         self.bound_model = model
         self.config_list = config_list
         self.validate_config(model=model, config_list=config_list)
@@ -70,7 +96,7 @@ class Compressor:
         self._unwrap_model()
         self._modules_to_compress = None
-        self.modules_wrapper = collections.OrderedDict()
+        self.modules_wrapper = {}
         for layer, config in self._detect_modules_to_compress():
             wrapper = self._wrap_modules(layer, config)
             self.modules_wrapper[layer.name] = wrapper
@@ -93,6 +119,8 @@ class Compressor:
         Detect all modules should be compressed, and save the result in `self._modules_to_compress`.
         The model will be instrumented and user should never edit it after calling this method.
         """
+        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
         if self._modules_to_compress is None:
             self._modules_to_compress = []
             for name, module in self.bound_model.named_modules():
@@ -118,6 +146,8 @@ class Compressor:
         Optional[Dict]
             The retrieved configuration for this layer, if None, this layer should not be compressed.
         """
+        assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
         ret = None
         for config in self.config_list:
             config = config.copy()
@@ -142,32 +172,26 @@ class Compressor:
                 return None
         return ret

-    def get_modules_wrapper(self) -> Dict[str, Module]:
+    def get_modules_wrapper(self) -> Dict[str, ModuleWrapper]:
         """
         Returns
         -------
-        OrderedDict[str, Module]
-            An ordered dict, key is the name of the module, value is the wrapper of the module.
+        Dict[str, ModuleWrapper]
+            A dict, key is the name of the module, value is the wrapper of the module.
         """
-        return self.modules_wrapper
+        raise NotImplementedError

     def _wrap_model(self):
         """
         Wrap all modules that needed to be compressed.
         """
-        if not self.is_wrapped:
-            for _, wrapper in reversed(self.get_modules_wrapper().items()):
-                _setattr(self.bound_model, wrapper.name, wrapper)
-            self.is_wrapped = True
+        raise NotImplementedError

     def _unwrap_model(self):
         """
         Unwrap all modules that needed to be compressed.
         """
-        if self.is_wrapped:
-            for _, wrapper in self.get_modules_wrapper().items():
-                _setattr(self.bound_model, wrapper.name, wrapper.module)
-            self.is_wrapped = False
+        raise NotImplementedError

     def set_wrappers_attribute(self, name: str, value: Any):
         """
@@ -182,7 +206,7 @@ class Compressor:
         value
             Value of the variable.
         """
-        for wrapper in self.get_modules_wrapper():
+        for wrapper in self.get_modules_wrapper().values():
             if isinstance(value, torch.Tensor):
                 wrapper.register_buffer(name, value.clone())
             else:
@@ -216,8 +240,10 @@ class Compressor:
         Dict[int, List[str]]
             A dict. The key is the config idx in config_list, the value is the module name list. i.e., {1: ['layer.0', 'layer.2']}.
         """
+        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
+        assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
         self._unwrap_model()
         module_groups = {}
         for name, module in self.bound_model.named_modules():
             if module == self.bound_model:
@@ -259,7 +285,7 @@ class Compressor:
         """
         raise NotImplementedError()

-    def _wrap_modules(self, layer: LayerInfo, config: Dict):
+    def _wrap_modules(self, layer: LayerInfo, config: Dict) -> ModuleWrapper:
         """
         This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`
@@ -297,4 +323,6 @@ class Compressor:
         torch.nn.Module
             model with specified modules compressed.
         """
+        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
+        assert self.config_list is not None, 'No config_list set in this compressor, please use Compressor.reset(model, config_list) to set it.'
         return self.bound_model
@@ -2,11 +2,12 @@
 # Licensed under the MIT license.
 import logging
-from typing import Dict, List, Optional, Tuple
+from typing import Dict, List, Optional, Tuple, OrderedDict
 import torch
 from torch import Tensor
-from torch.nn import Module, Parameter
+from torch.nn import Module
+from torch.nn.parameter import Parameter
 from .compressor import Compressor, LayerInfo, _setattr
@@ -37,15 +38,15 @@ class PrunerModuleWrapper(Module):
         # config information
         self.config = config
-        self.weight = Parameter(torch.empty(self.module.weight.size()))
-        # register buffer for mask
-        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
-        if hasattr(self.module, 'bias') and self.module.bias is not None:
-            self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
-            self.bias = Parameter(torch.empty(self.module.bias.size()))
-        else:
-            self.register_buffer("bias_mask", None)
+        pruning_target_names = ['weight', 'bias']
+        for pruning_target_name in pruning_target_names:
+            pruning_target_mask_name = '{}_mask'.format(pruning_target_name)
+            pruning_target = getattr(self.module, pruning_target_name, None)
+            if hasattr(self.module, pruning_target_name) and pruning_target is not None:
+                setattr(self, pruning_target_name, Parameter(torch.empty(pruning_target.shape)))
+                self.register_buffer(pruning_target_mask_name, torch.ones(pruning_target.shape))
+            else:
+                self.register_buffer(pruning_target_mask_name, None)

     def _weight2buffer(self):
         """
@@ -89,7 +90,17 @@ class Pruner(Compressor):
     def reset(self, model: Optional[Module] = None, config_list: Optional[List[Dict]] = None):
         super().reset(model=model, config_list=config_list)

-    def _wrap_modules(self, layer: LayerInfo, config: Dict):
+    def get_modules_wrapper(self) -> OrderedDict[str, PrunerModuleWrapper]:
+        """
+        Returns
+        -------
+        OrderedDict[str, PrunerModuleWrapper]
+            An ordered dict, key is the name of the module, value is the wrapper of the module.
+        """
+        assert self.modules_wrapper is not None, 'Bound model has not been wrapped.'
+        return self.modules_wrapper
+
+    def _wrap_modules(self, layer: LayerInfo, config: Dict) -> PrunerModuleWrapper:
         """
         Create a wrapper module to replace the original one.
@@ -99,6 +110,11 @@ class Pruner(Compressor):
             The layer to instrument the mask.
         config
             The configuration for generating the mask.
+
+        Returns
+        -------
+        PrunerModuleWrapper
+            The wrapper of the module in layerinfo.
         """
         _logger.debug("Module detected to compress : %s.", layer.name)
         wrapper = PrunerModuleWrapper(layer.module, layer.name, config)
@@ -114,8 +130,10 @@ class Pruner(Compressor):
         Wrap all modules that needed to be compressed.
         Different from the parent function, call `wrapper._weight2buffer()` after replace the origin module to wrapper.
         """
+        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
         if not self.is_wrapped:
-            for _, wrapper in reversed(self.get_modules_wrapper().items()):
+            for _, wrapper in reversed(list(self.get_modules_wrapper().items())):
                 _setattr(self.bound_model, wrapper.name, wrapper)
                 wrapper._weight2buffer()
             self.is_wrapped = True
@@ -125,8 +143,10 @@ class Pruner(Compressor):
         Unwrap all modules that needed to be compressed.
         Different from the parent function, call `wrapper._weight2parameter()` after replace the wrapper to origin module.
         """
+        assert self.bound_model is not None, 'No model bounded in this compressor, please use Compressor.reset(model, config_list) to set it.'
         if self.is_wrapped:
-            for _, wrapper in self.get_modules_wrapper().items():
+            for wrapper in self.get_modules_wrapper().values():
                 _setattr(self.bound_model, wrapper.name, wrapper.module)
                 wrapper._weight2parameter()
             self.is_wrapped = False
@@ -191,7 +211,7 @@ class Pruner(Compressor):
         dim
             The pruned dim.
         """
-        for _, wrapper in self.get_modules_wrapper().items():
+        for wrapper in self.get_modules_wrapper().values():
             weight_mask = wrapper.weight_mask
             mask_size = weight_mask.size()
             if len(mask_size) == 1:
...
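The `reversed(list(...))` change in `_wrap_model` above is needed because `modules_wrapper` is now a plain dict: `reversed()` only accepts dict views from Python 3.8 on, and raises TypeError on a `dict_items` view under older interpreters, while a materialized list is reversible everywhere. A quick illustration:

wrappers = {'backbone.fc1': 'wrapper1', 'backbone.fc2': 'wrapper2'}

# reversed(wrappers.items()) raises TypeError before Python 3.8;
# list() materializes the view and works on any version, keeping insertion order.
for name, wrapper in reversed(list(wrappers.items())):
    print(name, wrapper)  # backbone.fc2 first, then backbone.fc1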
@@ -5,7 +5,7 @@ import gc
 import logging
 import os
 from pathlib import Path
-from typing import List, Dict, Tuple, Optional
+from typing import List, Dict, Tuple, Optional, Union
 import json_tricks
 import torch
@@ -19,7 +19,7 @@ class Task:
     # NOTE: If we want to support multi-thread, this part need to refactor, maybe use file and lock to sync.
     _reference_counter = {}

-    def __init__(self, task_id: int, model_path: str, masks_path: str, config_list_path: str,
+    def __init__(self, task_id: int, model_path: Union[str, Path], masks_path: Union[str, Path], config_list_path: Union[str, Path],
                  speedup: Optional[bool] = True, finetune: Optional[bool] = True, evaluate: Optional[bool] = True):
         """
         Parameters
@@ -87,7 +87,7 @@ class Task:
             config_list = json_tricks.load(f)
         return model, masks, config_list

-    def referenced_paths(self) -> List[str]:
+    def referenced_paths(self) -> List[Union[str, Path]]:
         """
         Return the path list that need to count reference in this task.
         """
@@ -111,7 +111,7 @@ class Task:
 class TaskResult:
-    def __init__(self, task_id: int, compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]],
+    def __init__(self, task_id: Union[int, str], compact_model: Module, compact_model_masks: Dict[str, Dict[str, Tensor]],
                  pruner_generated_masks: Dict[str, Dict[str, Tensor]], score: Optional[float]) -> None:
         """
         Parameters
...
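Widening the path parameters to Union[str, Path] usually pairs with normalizing once at the boundary, since pathlib.Path() accepts both forms. A hedged sketch (the load_config_list helper is illustrative, not from this commit):

from pathlib import Path
from typing import Union

def load_config_list(config_list_path: Union[str, Path]) -> str:
    # Path() accepts str and Path alike, so callers never need
    # to care which one they hold.
    return Path(config_list_path).read_text()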
@@ -82,12 +82,13 @@ class AMCTaskGenerator(TaskGenerator):
     def generate_tasks(self, task_result: TaskResult) -> List[Task]:
         # append experience & update agent policy
-        if task_result.task_id != 'origin':
+        if self.action is not None:
             action, reward, observation, done = self.env.step(self.action, task_result.compact_model)
             self.T.append([reward, self.observation, observation, self.action, done])
             self.observation = observation.copy()

             if done:
+                assert task_result.score is not None, 'task_result.score should not be None if environment is done.'
                 final_reward = task_result.score - 1
                 # agent observe and update policy
                 for _, s_t, s_t1, a_t, d_t in self.T:
...
@@ -46,7 +46,9 @@ class AutoCompressTaskGenerator(LotteryTicketTaskGenerator):
     def allocate_sparsity(self, new_config_list: List[Dict], model: Module, masks: Dict[str, Dict[str, Tensor]]):
         self._iterative_pruner_reset(model, new_config_list, masks)
         self.iterative_pruner.compress()
-        _, _, _, _, config_list = self.iterative_pruner.get_best_result()
+        best_result = self.iterative_pruner.get_best_result()
+        assert best_result is not None, 'Best result does not exist, iterative pruner may not start pruning.'
+        _, _, _, _, config_list = best_result
         return config_list
@@ -149,7 +151,7 @@ class AutoCompressPruner(IterativePruner):
     def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
                  sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
                  finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False,
-                 dummy_input: Optional[Tensor] = None, evaluator: Callable[[Module], float] = None):
+                 dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None):
         task_generator = AutoCompressTaskGenerator(total_iteration=total_iteration,
                                                    origin_model=model,
                                                    origin_config_list=config_list,
...
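The evaluator fix above is the recurring pattern of this commit: `Callable[[Module], float] = None` relies on implicit Optional, which PEP 484 deprecates and which mypy rejects under its recent default `no_implicit_optional` behavior; spelling out `Optional[...]` makes the None default explicit. For example:

from typing import Callable, Optional
from torch.nn import Module

# rejected by strict checkers (implicit Optional):
#     def run(evaluator: Callable[[Module], float] = None): ...
# accepted form, with an explicit None check before use:
def run(evaluator: Optional[Callable[[Module], float]] = None) -> None:
    if evaluator is not None:
        print('would evaluate the model here')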
@@ -8,7 +8,7 @@ from typing import List, Dict, Tuple, Callable, Optional
 from schema import And, Or, Optional as SchemaOptional, SchemaError
 import torch
 from torch import Tensor
-import torch.nn as nn
+import torch.nn.functional as F
 from torch.nn import Module
 from torch.optim import Optimizer
@@ -77,10 +77,10 @@ INTERNAL_SCHEMA = {
 class BasicPruner(Pruner):
-    def __init__(self, model: Module, config_list: List[Dict]):
-        self.data_collector: DataCollector = None
-        self.metrics_calculator: MetricsCalculator = None
-        self.sparsity_allocator: SparsityAllocator = None
+    def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]]):
+        self.data_collector: Optional[DataCollector] = None
+        self.metrics_calculator: Optional[MetricsCalculator] = None
+        self.sparsity_allocator: Optional[SparsityAllocator] = None

         super().__init__(model, config_list)
@@ -114,6 +114,8 @@ class BasicPruner(Pruner):
         Tuple[Module, Dict]
             Return the wrapped model and mask.
         """
+        assert self.bound_model is not None and self.config_list is not None, 'Model and/or config_list are not set in this pruner, please set them by reset() before compress().'
+        assert self.data_collector is not None and self.metrics_calculator is not None and self.sparsity_allocator is not None
         data = self.data_collector.collect()
         _logger.debug('Collected Data:\n%s', data)
         metrics = self.metrics_calculator.calculate_metrics(data)
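The two asserts guard the three-stage pipeline that BasicPruner.compress() runs; in outline, the flow the surrounding lines implement looks roughly like this (a sketch of the control flow, not the exact NNI code):

def compress_outline(pruner):
    # 1. run training / inference to gather raw data (weights, activations, gradients)
    data = pruner.data_collector.collect()
    # 2. condense the raw data into a per-layer importance metric
    metrics = pruner.metrics_calculator.calculate_metrics(data)
    # 3. turn metrics plus the configured sparsity into binary masks
    masks = pruner.sparsity_allocator.generate_sparsity(metrics)
    return masks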
@@ -553,8 +555,8 @@ class SlimPruner(BasicPruner):
     def criterion_patch(self, criterion: Callable[[Tensor, Tensor], Tensor]) -> Callable[[Tensor, Tensor], Tensor]:
         def patched_criterion(input_tensor: Tensor, target: Tensor):
             sum_l1 = 0
-            for _, wrapper in self.get_modules_wrapper().items():
-                sum_l1 += torch.norm(wrapper.module.weight, p=1)
+            for wrapper in self.get_modules_wrapper().values():
+                sum_l1 += torch.norm(wrapper.module.weight, p=1)  # type: ignore
             return criterion(input_tensor, target) + self._scale * sum_l1
         return patched_criterion
@@ -654,11 +656,11 @@ class ActivationPruner(BasicPruner):
     def _choose_activation(self, activation: str = 'relu') -> Callable:
         if activation == 'relu':
-            return nn.functional.relu
+            return F.relu
         elif activation == 'relu6':
-            return nn.functional.relu6
+            return F.relu6
         else:
-            raise 'Unsupported activatoin {}'.format(activation)
+            raise Exception('Unsupported activation {}'.format(activation))

     def _collector(self, buffer: List) -> Callable[[Module, Tensor, Tensor], None]:
         assert len(buffer) == 0, 'Buffer pass to activation pruner collector is not empty.'
@@ -684,7 +686,7 @@ class ActivationPruner(BasicPruner):
             self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
                                                                       1, collector_infos=[collector_info])
         else:
-            self.data_collector.reset(collector_infos=[collector_info])
+            self.data_collector.reset(collector_infos=[collector_info])  # type: ignore
         if self.metrics_calculator is None:
             self.metrics_calculator = self._get_metrics_calculator()
         if self.sparsity_allocator is None:
@@ -999,13 +1001,13 @@ class TaylorFOWeightPruner(BasicPruner):
         return (weight_tensor.detach() * grad.detach()).data.pow(2)

     def reset_tools(self):
-        hook_targets = {name: wrapper.weight for name, wrapper in self.get_modules_wrapper().items()}
-        collector_info = HookCollectorInfo(hook_targets, 'tensor', self._collector)
+        hook_targets = {name: wrapper.weight for name, wrapper in self.get_modules_wrapper().items()}  # type: ignore
+        collector_info = HookCollectorInfo(hook_targets, 'tensor', self._collector)  # type: ignore
         if self.data_collector is None:
             self.data_collector = SingleHookTrainerBasedDataCollector(self, self.trainer, self.optimizer_helper, self.criterion,
                                                                       1, collector_infos=[collector_info])
         else:
-            self.data_collector.reset(collector_infos=[collector_info])
+            self.data_collector.reset(collector_infos=[collector_info])  # type: ignore
         if self.metrics_calculator is None:
             self.metrics_calculator = MultiDataNormMetricsCalculator(p=1, dim=0)
         if self.sparsity_allocator is None:
@@ -1095,24 +1097,26 @@ class ADMMPruner(BasicPruner):
     For detailed example please refer to :githublink:`examples/model_compress/pruning/admm_pruning_torch.py <examples/model_compress/pruning/admm_pruning_torch.py>`
     """

-    def __init__(self, model: Module, config_list: List[Dict], trainer: Callable[[Module, Optimizer, Callable], None],
+    def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]], trainer: Callable[[Module, Optimizer, Callable], None],
                  traced_optimizer: Traceable, criterion: Callable[[Tensor, Tensor], Tensor], iterations: int,
                  training_epochs: int, granularity: str = 'fine-grained'):
         self.trainer = trainer
         if isinstance(traced_optimizer, OptimizerConstructHelper):
             self.optimizer_helper = traced_optimizer
         else:
+            assert model is not None, 'Model is required if traced_optimizer is provided.'
             self.optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
         self.criterion = criterion
         self.iterations = iterations
         self.training_epochs = training_epochs
         assert granularity in ['fine-grained', 'coarse-grained']
         self.granularity = granularity
+        self.Z, self.U = {}, {}
         super().__init__(model, config_list)

-    def reset(self, model: Optional[Module], config_list: Optional[List[Dict]]):
+    def reset(self, model: Module, config_list: List[Dict]):
         super().reset(model, config_list)
-        self.Z = {name: wrapper.module.weight.data.clone().detach() for name, wrapper in self.get_modules_wrapper().items()}
+        self.Z = {name: wrapper.module.weight.data.clone().detach() for name, wrapper in self.get_modules_wrapper().items()}  # type: ignore
         self.U = {name: torch.zeros_like(z).to(z.device) for name, z in self.Z.items()}

     def _validate_config_before_canonical(self, model: Module, config_list: List[Dict]):
@@ -1156,6 +1160,8 @@ class ADMMPruner(BasicPruner):
         Tuple[Module, Dict]
             Return the wrapped model and mask.
         """
+        assert self.bound_model is not None
+        assert self.data_collector is not None and self.metrics_calculator is not None and self.sparsity_allocator is not None
         for i in range(self.iterations):
             _logger.info('======= ADMM Iteration %d Start =======', i)
             data = self.data_collector.collect()
@@ -1169,11 +1175,10 @@ class ADMMPruner(BasicPruner):
                 self.Z[name] = self.Z[name].mul(mask['weight'])
                 self.U[name] = self.U[name] + data[name] - self.Z[name]

-        self.Z = None
-        self.U = None
+        self.Z, self.U = {}, {}
         torch.cuda.empty_cache()

-        metrics = self.metrics_calculator.calculate_metrics(data)
+        metrics = self.metrics_calculator.calculate_metrics(data)  # type: ignore
         masks = self.sparsity_allocator.generate_sparsity(metrics)
         self.load_masks(masks)
...
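For context, the Z/U bookkeeping in this hunk is the scaled-form ADMM update that the pruner iterates; resetting the pair to empty dicts instead of None also keeps their declared dict type stable for the checker. Schematically (a hedged paraphrase of the loop above, not new behavior):

# scaled-form ADMM for sparsity-constrained pruning, per layer `name`:
#   Z <- Proj_sparse(W + U)   implemented above by masking Z with mask['weight']
#   U <- U + W - Z            dual variable accumulating the constraint residual
Z[name] = (data[name] + U[name]) * mask['weight']   # projection onto the sparse set
U[name] = U[name] + data[name] - Z[name]            # running residual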
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.

 from copy import deepcopy
-from typing import Dict, List, Tuple, Callable, Optional
+from typing import Dict, List, Tuple, Callable, Optional, Union

 import torch
 from torch import Tensor
@@ -36,8 +36,8 @@ class PruningScheduler(BasePruningScheduler):
     reset_weight
         If set True, the model weight will reset to the origin model weight at the end of each iteration step.
     """
-    def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Callable[[Module], None] = None,
-                 speedup: bool = False, dummy_input: Tensor = None, evaluator: Optional[Callable[[Module], float]] = None,
+    def __init__(self, pruner: Pruner, task_generator: TaskGenerator, finetuner: Optional[Callable[[Module], None]] = None,
+                 speedup: bool = False, dummy_input: Optional[Tensor] = None, evaluator: Optional[Callable[[Module], float]] = None,
                  reset_weight: bool = False):
         self.pruner = pruner
         self.task_generator = task_generator
@@ -155,5 +155,5 @@ class PruningScheduler(BasePruningScheduler):
         torch.cuda.empty_cache()
         return result

-    def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
+    def get_best_result(self) -> Optional[Tuple[Union[int, str], Module, Dict[str, Dict[str, Tensor]], Optional[float], List[Dict]]]:
         return self.task_generator.get_best_result()
@@ -2,7 +2,8 @@
 # Licensed under the MIT license.

 import logging
-from typing import Dict, List, Callable, Optional
+from pathlib import Path
+from typing import Dict, List, Callable, Optional, Union

 from torch import Tensor
 from torch.nn import Module
@@ -293,9 +294,9 @@ class SimulatedAnnealingPruner(IterativePruner):
     Parameters
     ----------
-    model : Module
+    model : Optional[Module]
         The origin unwrapped pytorch model to be pruned.
-    config_list : List[Dict]
+    config_list : Optional[List[Dict]]
         The origin config list provided by the user.
     evaluator : Callable[[Module], float]
         Evaluate the pruned model and give a score.
@@ -312,7 +313,7 @@ class SimulatedAnnealingPruner(IterativePruner):
         This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
     pruning_params : Dict
         If the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
-    log_dir : str
+    log_dir : Union[str, Path]
         The log directory use to saving the result, you can find the best result under this folder.
     keep_intermediate_result : bool
         If keeping the intermediate result, including intermediate model and masks during each iteration.
@@ -337,9 +338,9 @@ class SimulatedAnnealingPruner(IterativePruner):
     For detailed example please refer to :githublink:`examples/model_compress/pruning/simulated_anealing_pruning_torch.py <examples/model_compress/pruning/simulated_anealing_pruning_torch.py>`
     """
-    def __init__(self, model: Module, config_list: List[Dict], evaluator: Callable[[Module], float], start_temperature: float = 100,
+    def __init__(self, model: Optional[Module], config_list: Optional[List[Dict]], evaluator: Callable[[Module], float], start_temperature: float = 100,
                  stop_temperature: float = 20, cool_down_rate: float = 0.9, perturbation_magnitude: float = 0.35,
-                 pruning_algorithm: str = 'level', pruning_params: Dict = {}, log_dir: str = '.', keep_intermediate_result: bool = False,
+                 pruning_algorithm: str = 'level', pruning_params: Dict = {}, log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False,
                  finetuner: Optional[Callable[[Module], None]] = None, speedup: bool = False, dummy_input: Optional[Tensor] = None):
         task_generator = SimulatedAnnealingTaskGenerator(origin_model=model,
                                                          origin_config_list=config_list,
@@ -350,7 +351,7 @@ class SimulatedAnnealingPruner(IterativePruner):
                                                          log_dir=log_dir,
                                                          keep_intermediate_result=keep_intermediate_result)
         if 'traced_optimizer' in pruning_params:
-            pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])
+            pruning_params['traced_optimizer'] = OptimizerConstructHelper.from_trace(model, pruning_params['traced_optimizer'])  # type: ignore
         pruner = PRUNER_DICT[pruning_algorithm](None, None, **pruning_params)
         super().__init__(pruner, task_generator, finetuner=finetuner, speedup=speedup, dummy_input=dummy_input,
                          evaluator=evaluator, reset_weight=False)
@@ -7,7 +7,8 @@ from typing import Dict, List, Tuple, Callable
 import torch
 from torch import autograd, Tensor
-from torch.nn import Module, Parameter
+from torch.nn import Module
+from torch.nn.parameter import Parameter
 from torch.optim import Optimizer, Adam

 from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper, LayerInfo
@@ -41,15 +42,15 @@ class PrunerScoredModuleWrapper(PrunerModuleWrapper):
     """
     def __init__(self, module: Module, module_name: str, config: Dict):
         super().__init__(module, module_name, config)
-        self.weight_score = Parameter(torch.empty(self.weight.size()))
+        self.weight_score = Parameter(torch.empty(self.weight.size()))  # type: ignore
         torch.nn.init.constant_(self.weight_score, val=0.0)

     def forward(self, *inputs):
         # apply mask to weight, bias
-        # NOTE: I don't know why training getting slower and slower if only `self.weight_mask` without `detach_()`
-        self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask.detach_()))
+        # NOTE: I don't know why training getting slower and slower if only `self.weight_mask` without `detach()`
+        self.module.weight = torch.mul(self.weight, _StraightThrough.apply(self.weight_score, self.weight_mask.detach()))  # type: ignore
         if hasattr(self.module, 'bias') and self.module.bias is not None:
-            self.module.bias = torch.mul(self.bias, self.bias_mask)
+            self.module.bias = torch.mul(self.bias, self.bias_mask)  # type: ignore
         return self.module(*inputs)
@@ -58,7 +59,7 @@ class _StraightThrough(autograd.Function):
     Straight through the gradient to the score, then the score = initial_score + sum(-lr * grad(weight) * weight).
     """
     @staticmethod
-    def forward(self, score, masks):
+    def forward(ctx, score, masks):
         return masks

     @staticmethod
@@ -71,12 +72,13 @@ class WeightScoreTrainerBasedDataCollector(TrainerBasedDataCollector):
     Collect all weight_score in wrappers as data used to calculate metrics.
     """
     def collect(self) -> Dict[str, Tensor]:
+        assert self.compressor.bound_model is not None
         for _ in range(self.training_epochs):
             self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)

         data = {}
         for _, wrapper in self.compressor.get_modules_wrapper().items():
-            data[wrapper.name] = wrapper.weight_score.data
+            data[wrapper.name] = wrapper.weight_score.data  # type: ignore
         return data
@@ -193,6 +195,7 @@ class MovementPruner(BasicPruner):
             self.sparsity_allocator = NormalSparsityAllocator(self, continuous_mask=False)

         # use Adam to update the weight_score
+        assert self.bound_model is not None
         params = [{"params": [p for n, p in self.bound_model.named_parameters() if "weight_score" in n and p.requires_grad]}]
         optimizer = Adam(params, 1e-2)
         self.step_counter = 0
@@ -205,10 +208,10 @@ class MovementPruner(BasicPruner):
             if self.step_counter > self.warm_up_step:
                 self.cubic_schedule(self.step_counter)
                 data = {}
-                for _, wrapper in self.get_modules_wrapper().items():
+                for wrapper in self.get_modules_wrapper().values():
                     data[wrapper.name] = wrapper.weight_score.data
-                metrics = self.metrics_calculator.calculate_metrics(data)
-                masks = self.sparsity_allocator.generate_sparsity(metrics)
+                metrics = self.metrics_calculator.calculate_metrics(data)  # type: ignore
+                masks = self.sparsity_allocator.generate_sparsity(metrics)  # type: ignore
                 self.load_masks(masks)

         if self.data_collector is None:
@@ -232,15 +235,15 @@ class MovementPruner(BasicPruner):
         wrapper = PrunerScoredModuleWrapper(layer.module, layer.name, config)
         assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
         # move newly registered buffers to the same device of weight
-        wrapper.to(layer.module.weight.device)
+        wrapper.to(layer.module.weight.device)  # type: ignore
         return wrapper

     def compress(self) -> Tuple[Module, Dict]:
         # sparsity grow from 0
-        for _, wrapper in self.get_modules_wrapper().items():
+        for wrapper in self.get_modules_wrapper().values():
             wrapper.config['total_sparsity'] = 0
         result = super().compress()
         # del weight_score
-        for _, wrapper in self.get_modules_wrapper().items():
+        for wrapper in self.get_modules_wrapper().values():
             wrapper.weight_score = None
         return result
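The `self` to `ctx` rename in _StraightThrough.forward above is more than cosmetic: autograd.Function methods are static, so the first argument is the autograd context object, and naming it `self` misleads both readers and type checkers. A self-contained straight-through estimator showing the same pattern (illustrative; it mirrors but does not reproduce the NNI class):

import torch
from torch import autograd

class StraightThrough(autograd.Function):
    """Forward returns the mask; backward passes the gradient straight to the score."""
    @staticmethod
    def forward(ctx, score, masks):
        return masks

    @staticmethod
    def backward(ctx, grad_output):
        # gradient flows to `score` unchanged; the mask input gets none
        return grad_output, None

score = torch.zeros(3, requires_grad=True)
mask = torch.tensor([1.0, 0.0, 1.0])
StraightThrough.apply(score, mask).sum().backward()
print(score.grad)  # tensor([1., 1., 1.])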
@@ -13,7 +13,7 @@ from torch import Tensor
 from torch.nn import Module
 from torch.optim import Optimizer

-from nni.algorithms.compression.v2.pytorch.base import Compressor, LayerInfo, Task, TaskResult
+from nni.algorithms.compression.v2.pytorch.base import Pruner, LayerInfo, Task, TaskResult
 from nni.algorithms.compression.v2.pytorch.utils import OptimizerConstructHelper

 _logger = logging.getLogger(__name__)
@@ -29,7 +29,7 @@ class DataCollector:
         The compressor binded with this DataCollector.
     """

-    def __init__(self, compressor: Compressor):
+    def __init__(self, compressor: Pruner):
         self.compressor = compressor

     def reset(self):
@@ -76,10 +76,10 @@ class TrainerBasedDataCollector(DataCollector):
     This class includes some trainer based util functions, i.e., patch optimizer or criterion, add hooks.
     """

-    def __init__(self, compressor: Compressor, trainer: Callable[[Module, Optimizer, Callable], None], optimizer_helper: OptimizerConstructHelper,
+    def __init__(self, compressor: Pruner, trainer: Callable[[Module, Optimizer, Callable], None], optimizer_helper: OptimizerConstructHelper,
                  criterion: Callable[[Tensor, Tensor], Tensor], training_epochs: int,
                  opt_before_tasks: List = [], opt_after_tasks: List = [],
-                 collector_infos: List[HookCollectorInfo] = [], criterion_patch: Callable[[Callable], Callable] = None):
+                 collector_infos: List[HookCollectorInfo] = [], criterion_patch: Optional[Callable[[Callable], Callable]] = None):
         """
         Parameters
         ----------
@@ -165,6 +165,7 @@ class TrainerBasedDataCollector(DataCollector):
     def _reset_optimizer(self):
         parameter_name_map = self.compressor.get_origin2wrapped_parameter_name_map()
+        assert self.compressor.bound_model is not None
         self.optimizer = self.optimizer_helper.call(self.compressor.bound_model, parameter_name_map)

     def _patch_optimizer(self):
@@ -187,11 +188,11 @@ class TrainerBasedDataCollector(DataCollector):
         self._hook_buffer[self._hook_id] = {}

         if collector_info.hook_type == 'forward':
-            self._add_forward_hook(self._hook_id, collector_info.targets, collector_info.collector)
+            self._add_forward_hook(self._hook_id, collector_info.targets, collector_info.collector)  # type: ignore
         elif collector_info.hook_type == 'backward':
-            self._add_backward_hook(self._hook_id, collector_info.targets, collector_info.collector)
+            self._add_backward_hook(self._hook_id, collector_info.targets, collector_info.collector)  # type: ignore
         elif collector_info.hook_type == 'tensor':
-            self._add_tensor_hook(self._hook_id, collector_info.targets, collector_info.collector)
+            self._add_tensor_hook(self._hook_id, collector_info.targets, collector_info.collector)  # type: ignore
         else:
             _logger.warning('Skip unsupported hook type: %s', collector_info.hook_type)
@@ -210,7 +211,7 @@ class TrainerBasedDataCollector(DataCollector):
         assert all(isinstance(layer_info, LayerInfo) for layer_info in layers)
         for layer in layers:
             self._hook_buffer[hook_id][layer.name] = []
-            handle = layer.module.register_backward_hook(collector(self._hook_buffer[hook_id][layer.name]))
+            handle = layer.module.register_backward_hook(collector(self._hook_buffer[hook_id][layer.name]))  # type: ignore
             self._hook_handles[hook_id][layer.name] = handle

     def _add_tensor_hook(self, hook_id: int, tensors: Dict[str, Tensor],
@@ -286,7 +287,7 @@ class MetricsCalculator:
             self.block_sparse_size = [1] * len(self.dim)
         if self.dim is not None:
             assert all(i >= 0 for i in self.dim)
-            self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))
+            self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))  # type: ignore

     def calculate_metrics(self, data: Dict) -> Dict[str, Tensor]:
         """
@@ -334,7 +335,7 @@ class SparsityAllocator:
         Inherit the mask already in the wrapper if set True.
     """

-    def __init__(self, pruner: Compressor, dim: Optional[Union[int, List[int]]] = None,
+    def __init__(self, pruner: Pruner, dim: Optional[Union[int, List[int]]] = None,
                  block_sparse_size: Optional[Union[int, List[int]]] = None, continuous_mask: bool = True):
         self.pruner = pruner
         self.dim = dim if not isinstance(dim, int) else [dim]
@@ -345,7 +346,7 @@ class SparsityAllocator:
             self.block_sparse_size = [1] * len(self.dim)
         if self.dim is not None:
             assert all(i >= 0 for i in self.dim)
-            self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))
+            self.dim, self.block_sparse_size = (list(t) for t in zip(*sorted(zip(self.dim, self.block_sparse_size))))  # type: ignore
         self.continuous_mask = continuous_mask

     def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
@@ -384,7 +385,7 @@ class SparsityAllocator:
             weight_mask = weight_mask.expand(expand_size).reshape(reshape_size)

         wrapper = self.pruner.get_modules_wrapper()[name]
-        weight_size = wrapper.weight.data.size()
+        weight_size = wrapper.weight.data.size()  # type: ignore

         if self.dim is None:
             assert weight_mask.size() == weight_size
@@ -401,7 +402,7 @@ class SparsityAllocator:
             expand_mask = {'weight': weight_mask.expand(weight_size).clone()}
             # NOTE: assume we only mask output, so the mask and bias have a one-to-one correspondence.
             # If we support more kind of masks, this place need refactor.
-            if wrapper.bias_mask is not None and weight_mask.size() == wrapper.bias_mask.size():
+            if wrapper.bias_mask is not None and weight_mask.size() == wrapper.bias_mask.size():  # type: ignore
                 expand_mask['bias'] = weight_mask.clone()
         return expand_mask
@@ -463,7 +464,7 @@ class TaskGenerator:
         If keeping the intermediate result, including intermediate model and masks during each iteration.
     """
     def __init__(self, origin_model: Optional[Module], origin_masks: Optional[Dict[str, Dict[str, Tensor]]] = {},
-                 origin_config_list: Optional[List[Dict]] = [], log_dir: str = '.', keep_intermediate_result: bool = False):
+                 origin_config_list: Optional[List[Dict]] = [], log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False):
         self._log_dir = log_dir
         self._keep_intermediate_result = keep_intermediate_result
@@ -486,7 +487,7 @@ class TaskGenerator:
         self._save_data('origin', model, masks, config_list)

         self._task_id_candidate = 0
-        self._tasks: Dict[int, Task] = {}
+        self._tasks: Dict[Union[int, str], Task] = {}
         self._pending_tasks: List[Task] = self.init_pending_tasks()

         self._best_score = None
@@ -560,7 +561,7 @@ class TaskGenerator:
         self._dump_tasks_info()
         return task

-    def get_best_result(self) -> Optional[Tuple[int, Module, Dict[str, Dict[str, Tensor]], float, List[Dict]]]:
+    def get_best_result(self) -> Optional[Tuple[Union[int, str], Module, Dict[str, Dict[str, Tensor]], Optional[float], List[Dict]]]:
         """
         Returns
         -------
...
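Most of the added asserts in this commit (and many of the `# type: ignore` comments) exist purely to satisfy static type checking: after `assert x is not None`, mypy narrows Optional[T] to T, so attribute access on the value stops being reported. A minimal sketch of the pattern:

from typing import Optional
from torch.nn import Module

class Holder:
    def __init__(self, model: Optional[Module] = None):
        self.bound_model = model

    def describe(self) -> None:
        # without this assert mypy reports:
        #   Item "None" of "Optional[Module]" has no attribute "named_modules"
        assert self.bound_model is not None, 'No model bound, call reset() first.'
        for name, _ in self.bound_model.named_modules():
            print(name)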
@@ -34,6 +34,7 @@ class WeightTrainerBasedDataCollector(TrainerBasedDataCollector):
     """
     def collect(self) -> Dict[str, Tensor]:
+        assert self.compressor.bound_model is not None
         for _ in range(self.training_epochs):
             self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
@@ -50,6 +51,7 @@ class SingleHookTrainerBasedDataCollector(TrainerBasedDataCollector):
     """
     def collect(self) -> Dict[str, List[Tensor]]:
+        assert self.compressor.bound_model is not None
         for _ in range(self.training_epochs):
             self.trainer(self.compressor.bound_model, self.optimizer, self.criterion)
...
@@ -70,7 +70,7 @@ class NormMetricsCalculator(MetricsCalculator):
             if len(across_dim) == 0:
                 metrics[name] = tensor.abs()
             else:
-                metrics[name] = tensor.norm(p=self.p, dim=across_dim)
+                metrics[name] = tensor.norm(p=self.p, dim=across_dim)  # type: ignore
         return metrics
@@ -142,7 +142,7 @@ class DistMetricsCalculator(MetricsCalculator):
             if len(across_dim) == 0:
                 dist_sum = torch.abs(reorder_tensor - other).sum()
             else:
-                dist_sum = torch.norm((reorder_tensor - other), p=self.p, dim=across_dim).sum()
+                dist_sum = torch.norm((reorder_tensor - other), p=self.p, dim=across_dim).sum()  # type: ignore
             # NOTE: this place need refactor when support layer level pruning.
             tmp_metric = metric
             for i in idx[:-1]:
...
@@ -141,7 +141,7 @@ class DDPG(nn.Module):
         ])

         target_q_batch = to_tensor(reward_batch) + \
-            self.discount * to_tensor(terminal_batch.astype(np.float)) * next_q_values
+            self.discount * to_tensor(terminal_batch.astype(np.float32)) * next_q_values

         # Critic update
         self.critic.zero_grad()
...
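The `np.float` to `np.float32` change is worth noting: `np.float` was merely an alias for the builtin `float` (64-bit), deprecated in NumPy 1.20 and removed in 1.24, so the old spelling eventually raises AttributeError. Casting to an explicit dtype is forward-compatible, and `float32` matches torch's default tensor dtype:

import numpy as np

terminal_batch = np.array([0, 1, 0])
# terminal_batch.astype(np.float)   # deprecated alias of builtin float; fails on NumPy >= 1.24
casted = terminal_batch.astype(np.float32)
print(casted.dtype)  # float32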
@@ -38,8 +38,8 @@ class AMCEnv:
         assert target in ['flops', 'params']
         self.target = target

-        self.origin_target, self.origin_params_num, self.origin_statistics = count_flops_params(model, dummy_input, verbose=False)
-        self.origin_statistics = {result['name']: result for result in self.origin_statistics}
+        self.origin_target, self.origin_params_num, origin_statistics = count_flops_params(model, dummy_input, verbose=False)
+        self.origin_statistics = {result['name']: result for result in origin_statistics}

         self.under_pruning_target = sum([self.origin_statistics[name][self.target] for name in self.pruning_op_names])
         self.excepted_pruning_target = self.total_sparsity * self.under_pruning_target
...
@@ -3,6 +3,7 @@
 from __future__ import absolute_import
 from collections import deque, namedtuple
+from typing import Any, List
 import warnings
 import random
...@@ -31,7 +32,7 @@ def sample_batch_indexes(low, high, size): ...@@ -31,7 +32,7 @@ def sample_batch_indexes(low, high, size):
'Not enough entries to sample without replacement. ' 'Not enough entries to sample without replacement. '
'Consider increasing your warm-up phase to avoid oversampling!') 'Consider increasing your warm-up phase to avoid oversampling!')
batch_idxs = np.random.random_integers(low, high - 1, size=size) batch_idxs = np.random.random_integers(low, high - 1, size=size)
assert len(batch_idxs) == size assert len(batch_idxs) == size # type: ignore
return batch_idxs return batch_idxs
...@@ -147,14 +148,14 @@ class SequentialMemory(Memory): ...@@ -147,14 +148,14 @@ class SequentialMemory(Memory):
# Skip this transition because the environment was reset here. Select a new, random # Skip this transition because the environment was reset here. Select a new, random
# transition and use this instead. This may cause the batch to contain the same # transition and use this instead. This may cause the batch to contain the same
# transition twice. # transition twice.
idx = sample_batch_indexes(1, self.nb_entries, size=1)[0] idx = sample_batch_indexes(1, self.nb_entries, size=1)[0] # type: ignore
terminal0 = self.terminals[idx - 2] if idx >= 2 else False terminal0 = self.terminals[idx - 2] if idx >= 2 else False
assert 1 <= idx < self.nb_entries assert 1 <= idx < self.nb_entries
# This code is slightly complicated by the fact that subsequent observations might be # This code is slightly complicated by the fact that subsequent observations might be
# from different episodes. We ensure that an experience never spans multiple episodes. # from different episodes. We ensure that an experience never spans multiple episodes.
# This is probably not that important in practice but it seems cleaner. # This is probably not that important in practice but it seems cleaner.
state0 = [self.observations[idx - 1]] state0: List[Any] = [self.observations[idx - 1]]
for offset in range(0, self.window_length - 1): for offset in range(0, self.window_length - 1):
current_idx = idx - 2 - offset current_idx = idx - 2 - offset
current_terminal = self.terminals[current_idx - 1] if current_idx - 1 > 0 else False current_terminal = self.terminals[current_idx - 1] if current_idx - 1 > 0 else False
......
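Note: `np.random.random_integers` has been deprecated since NumPy 1.11, which is why its stub trips mypy; this commit keeps the call and adds `# type: ignore`. Had the code been migrated instead, `np.random.randint` samples the identical range with a half-open interval (a sketch, not part of this commit):

    import numpy as np

    low, high, size = 1, 10, 4
    # random_integers(low, high - 1) samples [low, high - 1] inclusive, which is
    # exactly randint(low, high) with its half-open [low, high) interval
    batch_idxs = np.random.randint(low, high, size=size)
    assert len(batch_idxs) == size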
...@@ -29,7 +29,7 @@ class NormalSparsityAllocator(SparsityAllocator): ...@@ -29,7 +29,7 @@ class NormalSparsityAllocator(SparsityAllocator):
# We assume the metric value are all positive right now. # We assume the metric value are all positive right now.
metric = metrics[name] metric = metrics[name]
if self.continuous_mask: if self.continuous_mask:
metric *= self._compress_mask(wrapper.weight_mask) metric *= self._compress_mask(wrapper.weight_mask) # type: ignore
prune_num = int(sparsity_rate * metric.numel()) prune_num = int(sparsity_rate * metric.numel())
if prune_num == 0: if prune_num == 0:
threshold = metric.min() - 1 threshold = metric.min() - 1
...@@ -64,7 +64,7 @@ class BankSparsityAllocator(SparsityAllocator): ...@@ -64,7 +64,7 @@ class BankSparsityAllocator(SparsityAllocator):
# We assume the metric value are all positive right now. # We assume the metric value are all positive right now.
metric = metrics[name] metric = metrics[name]
if self.continuous_mask: if self.continuous_mask:
metric *= self._compress_mask(wrapper.weight_mask) metric *= self._compress_mask(wrapper.weight_mask) # type: ignore
n_dim = len(metric.shape) n_dim = len(metric.shape)
assert n_dim >= len(self.balance_gran), 'Dimension of balance_gran should be smaller than metric' assert n_dim >= len(self.balance_gran), 'Dimension of balance_gran should be smaller than metric'
# make up for balance_gran # make up for balance_gran
...@@ -129,15 +129,15 @@ class GlobalSparsityAllocator(SparsityAllocator): ...@@ -129,15 +129,15 @@ class GlobalSparsityAllocator(SparsityAllocator):
# We assume the metric value are all positive right now. # We assume the metric value are all positive right now.
if self.continuous_mask: if self.continuous_mask:
metric = metric * self._compress_mask(wrapper.weight_mask) metric = metric * self._compress_mask(wrapper.weight_mask) # type: ignore
layer_weight_num = wrapper.weight.data.numel() layer_weight_num = wrapper.weight.data.numel() # type: ignore
total_weight_num += layer_weight_num total_weight_num += layer_weight_num
expend_times = int(layer_weight_num / metric.numel()) expend_times = int(layer_weight_num / metric.numel())
retention_ratio = 1 - max_sparsity_per_layer.get(name, 1) retention_ratio = 1 - max_sparsity_per_layer.get(name, 1)
retention_numel = math.ceil(retention_ratio * layer_weight_num) retention_numel = math.ceil(retention_ratio * layer_weight_num)
removed_metric_num = math.ceil(retention_numel / (wrapper.weight_mask.numel() / metric.numel())) removed_metric_num = math.ceil(retention_numel / (wrapper.weight_mask.numel() / metric.numel())) # type: ignore
stay_metric_num = metric.numel() - removed_metric_num stay_metric_num = metric.numel() - removed_metric_num
if stay_metric_num <= 0: if stay_metric_num <= 0:
sub_thresholds[name] = metric.min().item() - 1 sub_thresholds[name] = metric.min().item() - 1
...@@ -182,7 +182,7 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator): ...@@ -182,7 +182,7 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
grouped_metric = {name: metrics[name] for name in names if name in metrics} grouped_metric = {name: metrics[name] for name in names if name in metrics}
if self.continuous_mask: if self.continuous_mask:
for name, metric in grouped_metric.items(): for name, metric in grouped_metric.items():
metric *= self._compress_mask(self.pruner.get_modules_wrapper()[name].weight_mask) metric *= self._compress_mask(self.pruner.get_modules_wrapper()[name].weight_mask) # type: ignore
if len(grouped_metric) > 0: if len(grouped_metric) > 0:
grouped_metrics[idx] = grouped_metric grouped_metrics[idx] = grouped_metric
for _, group_metric_dict in grouped_metrics.items(): for _, group_metric_dict in grouped_metrics.items():
......
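Note: the allocator hunks all share one pattern: when `continuous_mask` is on, the metric is multiplied by the compressed form of the existing weight mask, so weights pruned in an earlier iteration keep a zero score and can never be revived. A toy sketch of that idea, with illustrative values:

    import torch

    metric = torch.tensor([0.9, 0.4, 0.7, 0.1])  # importance score per weight group
    old_mask = torch.tensor([1., 0., 1., 1.])    # 0 = pruned in an earlier iteration
    metric = metric * old_mask                   # pruned groups keep a zero score
    threshold = metric.kthvalue(2).values        # cut the two lowest-scoring groups
    new_mask = (metric > threshold).float()
    print(new_mask)                              # tensor([1., 0., 1., 0.])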
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
from copy import deepcopy from copy import deepcopy
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Dict, List, Tuple from typing import Dict, List, Optional, Tuple, Union
import json_tricks import json_tricks
import numpy as np import numpy as np
...@@ -150,9 +150,9 @@ class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator): ...@@ -150,9 +150,9 @@ class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
class SimulatedAnnealingTaskGenerator(TaskGenerator): class SimulatedAnnealingTaskGenerator(TaskGenerator):
def __init__(self, origin_model: Module, origin_config_list: List[Dict], origin_masks: Dict[str, Dict[str, Tensor]] = {}, def __init__(self, origin_model: Optional[Module], origin_config_list: Optional[List[Dict]], origin_masks: Dict[str, Dict[str, Tensor]] = {},
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9, start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermediate_result: bool = False): perturbation_magnitude: float = 0.35, log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False):
""" """
Parameters Parameters
---------- ----------
...@@ -196,9 +196,9 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -196,9 +196,9 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
self.target_sparsity_list = config_list_canonical(model, config_list) self.target_sparsity_list = config_list_canonical(model, config_list)
self._adjust_target_sparsity() self._adjust_target_sparsity()
self._temp_config_list = None self._temp_config_list = []
self._current_sparsity_list = None self._current_sparsity_list = []
self._current_score = None self._current_score = 0.
super().reset(model, config_list=config_list, masks=masks) super().reset(model, config_list=config_list, masks=masks)
...@@ -248,7 +248,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -248,7 +248,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
return self._sparsity_to_config_list(rescaled_sparsity, config), rescaled_sparsity return self._sparsity_to_config_list(rescaled_sparsity, config), rescaled_sparsity
def _rescale_sparsity(self, random_sparsity: List, target_sparsity: float, op_names: List) -> List: def _rescale_sparsity(self, random_sparsity: List, target_sparsity: float, op_names: List) -> Optional[List]:
assert len(random_sparsity) == len(op_names) assert len(random_sparsity) == len(op_names)
num_weights = sorted([self.weights_numel[op_name] for op_name in op_names]) num_weights = sorted([self.weights_numel[op_name] for op_name in op_names])
...@@ -267,7 +267,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -267,7 +267,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
scale = target_sparsity / (total_weights_pruned / total_weights) scale = target_sparsity / (total_weights_pruned / total_weights)
# rescale the sparsity # rescale the sparsity
sparsity = np.asarray(sparsity) * scale sparsity = list(np.asarray(sparsity) * scale)
return sparsity return sparsity
def _sparsity_to_config_list(self, sparsity: List, config: Dict) -> List[Dict]: def _sparsity_to_config_list(self, sparsity: List, config: Dict) -> List[Dict]:
...@@ -285,7 +285,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -285,7 +285,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
# decrease magnitude with current temperature # decrease magnitude with current temperature
magnitude = self.current_temperature / self.start_temperature * self.perturbation_magnitude magnitude = self.current_temperature / self.start_temperature * self.perturbation_magnitude
for config, current_sparsity in zip(self.target_sparsity_list, self._current_sparsity_list): for config, current_sparsity in zip(self.target_sparsity_list, self._current_sparsity_list):
if len(current_sparsity) == 0: if not current_sparsity:
sub_temp_config_list = [deepcopy(config) for i in range(len(config['op_names']))] sub_temp_config_list = [deepcopy(config) for i in range(len(config['op_names']))]
for temp_config, op_name in zip(sub_temp_config_list, config['op_names']): for temp_config, op_name in zip(sub_temp_config_list, config['op_names']):
temp_config.update({'total_sparsity': 0, 'op_names': [op_name]}) temp_config.update({'total_sparsity': 0, 'op_names': [op_name]})
...@@ -327,11 +327,12 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator): ...@@ -327,11 +327,12 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
def generate_tasks(self, task_result: TaskResult) -> List[Task]: def generate_tasks(self, task_result: TaskResult) -> List[Task]:
# initial/update temp config list # initial/update temp config list
if self._temp_config_list is None: if not self._temp_config_list:
self._init_temp_config_list() self._init_temp_config_list()
else: else:
score = self._tasks[task_result.task_id].score score = self._tasks[task_result.task_id].score
if self._current_sparsity_list is None: assert score is not None, 'SimulatedAnnealingTaskGenerator need each score is not None.'
if not self._current_sparsity_list:
self._current_sparsity_list = deepcopy(self._temp_sparsity_list) self._current_sparsity_list = deepcopy(self._temp_sparsity_list)
self._current_score = score self._current_score = score
else: else:
......
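Note: the `SimulatedAnnealingTaskGenerator` changes replace `None` sentinels with typed empty values (`[]`, `0.`) and test truthiness instead of `is None`, removing the `Optional` unions mypy would otherwise propagate through the class. A minimal sketch, assuming an empty list is never a meaningful state (names are illustrative):

    from typing import Dict, List

    class Annealer:
        def __init__(self) -> None:
            # typed empty values instead of None keep the attributes non-Optional
            self._temp_config_list: List[Dict] = []
            self._current_score: float = 0.

        def step(self, new_config: Dict) -> None:
            if not self._temp_config_list:  # replaces "if ... is None"
                self._temp_config_list = [new_config]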
...@@ -19,7 +19,7 @@ class ConstructHelper: ...@@ -19,7 +19,7 @@ class ConstructHelper:
def __init__(self, callable_obj: Callable, *args, **kwargs): def __init__(self, callable_obj: Callable, *args, **kwargs):
assert callable(callable_obj), '`callable_obj` must be a callable object.' assert callable(callable_obj), '`callable_obj` must be a callable object.'
self.callable_obj = callable_obj self.callable_obj = callable_obj
self.args = deepcopy(args) self.args = deepcopy(list(args))
self.kwargs = deepcopy(kwargs) self.kwargs = deepcopy(kwargs)
def call(self): def call(self):
......
...@@ -149,14 +149,14 @@ def compute_sparsity_compact2origin(origin_model: Module, compact_model: Module, ...@@ -149,14 +149,14 @@ def compute_sparsity_compact2origin(origin_model: Module, compact_model: Module,
continue continue
if 'op_names' in config and module_name not in config['op_names']: if 'op_names' in config and module_name not in config['op_names']:
continue continue
total_weight_num += module.weight.data.numel() total_weight_num += module.weight.data.numel() # type: ignore
for module_name, module in compact_model.named_modules(): for module_name, module in compact_model.named_modules():
module_type = type(module).__name__ module_type = type(module).__name__
if 'op_types' in config and module_type not in config['op_types']: if 'op_types' in config and module_type not in config['op_types']:
continue continue
if 'op_names' in config and module_name not in config['op_names']: if 'op_names' in config and module_name not in config['op_names']:
continue continue
left_weight_num += module.weight.data.numel() left_weight_num += module.weight.data.numel() # type: ignore
compact2origin_sparsity.append(deepcopy(config)) compact2origin_sparsity.append(deepcopy(config))
compact2origin_sparsity[-1]['total_sparsity'] = 1 - left_weight_num / total_weight_num compact2origin_sparsity[-1]['total_sparsity'] = 1 - left_weight_num / total_weight_num
return compact2origin_sparsity return compact2origin_sparsity
...@@ -179,7 +179,7 @@ def compute_sparsity_mask2compact(compact_model: Module, compact_model_masks: Di ...@@ -179,7 +179,7 @@ def compute_sparsity_mask2compact(compact_model: Module, compact_model_masks: Di
continue continue
if 'op_names' in config and module_name not in config['op_names']: if 'op_names' in config and module_name not in config['op_names']:
continue continue
module_weight_num = module.weight.data.numel() module_weight_num = module.weight.data.numel() # type: ignore
total_weight_num += module_weight_num total_weight_num += module_weight_num
if module_name in compact_model_masks: if module_name in compact_model_masks:
weight_mask = compact_model_masks[module_name]['weight'] weight_mask = compact_model_masks[module_name]['weight']
...@@ -229,7 +229,7 @@ def compute_sparsity(origin_model: Module, compact_model: Module, compact_model_ ...@@ -229,7 +229,7 @@ def compute_sparsity(origin_model: Module, compact_model: Module, compact_model_
return current2origin_sparsity, compact2origin_sparsity, mask2compact_sparsity return current2origin_sparsity, compact2origin_sparsity, mask2compact_sparsity
def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}) -> Dict: def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[str, Dict[str, Tensor]] = {}) -> Tuple[Dict[str, int], Dict[str, float]]:
""" """
Count the layer weight elements number in config_list. Count the layer weight elements number in config_list.
If masks is not empty, the masked weight will not be counted. If masks is not empty, the masked weight will not be counted.
...@@ -248,7 +248,7 @@ def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[ ...@@ -248,7 +248,7 @@ def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[
masked_rate[module_name] = 1 - (weight_mask.sum().item() / weight_mask.numel()) masked_rate[module_name] = 1 - (weight_mask.sum().item() / weight_mask.numel())
model_weights_numel[module_name] = round(weight_mask.sum().item()) model_weights_numel[module_name] = round(weight_mask.sum().item())
else: else:
model_weights_numel[module_name] = module.weight.data.numel() model_weights_numel[module_name] = module.weight.data.numel() # type: ignore
return model_weights_numel, masked_rate return model_weights_numel, masked_rate
......
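Note: the utility hunks close with the sparsity bookkeeping used throughout these files: `compact2origin` sparsity is one minus the ratio of remaining to original weight elements. A worked sketch of that arithmetic with made-up counts:

    total_weight_num = 1_000_000  # weights covered by the config in the origin model
    left_weight_num = 250_000     # weights still present in the compact model
    total_sparsity = 1 - left_weight_num / total_weight_num
    print(total_sparsity)         # 0.75: three quarters of the covered weights are gone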