Unverified Commit 6dfdc546 authored by J-shang, committed by GitHub

[Compression v2] Add optimizer & lr scheduler construct helper (#4332)

parent 7978c25a
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from copy import deepcopy
from typing import Callable, Dict, List, Type
from torch import Tensor
from torch.nn import Module
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
from nni.common.serializer import Traceable, _trace_cls

__all__ = ['OptimizerConstructHelper', 'LRSchedulerConstructHelper', 'trace_parameters']


def trace_parameters(base, kw_only=True):
    """Wrap a class so that the arguments passed to its constructor are recorded
    (as symbol, args and kwargs) and the instance can be re-constructed later."""
    if not isinstance(base, type):
        raise TypeError('Only a class can be traced by this function.')
    return _trace_cls(base, kw_only, call_super=False)


class ConstructHelper:
    def __init__(self, callable_obj: Callable, *args, **kwargs):
        assert callable(callable_obj), '`callable_obj` must be a callable object.'
        self.callable_obj = callable_obj
        # keep a mutable, independent copy of the construct arguments
        self.args = deepcopy(list(args))
        self.kwargs = deepcopy(kwargs)

    def call(self):
        args = deepcopy(self.args)
        kwargs = deepcopy(self.kwargs)
        return self.callable_obj(*args, **kwargs)


class OptimizerConstructHelper(ConstructHelper):
    def __init__(self, model: Module, optimizer_class: Type[Optimizer], *args, **kwargs):
        assert isinstance(model, Module), 'Only support pytorch module.'
        assert issubclass(optimizer_class, Optimizer), 'Only support pytorch optimizer.'
        args = list(args)
        # replace the parameter tensors with parameter names, so that the optimizer can be
        # re-constructed later on a wrapped copy of the model
        if 'params' in kwargs:
            kwargs['params'] = self.params2names(model, kwargs['params'])
        else:
            args[0] = self.params2names(model, args[0])
        super().__init__(optimizer_class, *args, **kwargs)

    def params2names(self, model: Module, params: List) -> List[Dict]:
        # normalize the `params` argument into parameter groups and replace each tensor with its name in `model`
        param_groups = list(params)
        assert len(param_groups) > 0
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]
        for param_group in param_groups:
            params = param_group['params']
            if isinstance(params, Tensor):
                params = [params]
            elif isinstance(params, set):
                raise TypeError('optimizer parameters need to be organized in ordered collections, but '
                                'the ordering of tensors in sets will change between runs. Please use a list instead.')
            else:
                params = list(params)
            param_ids = [id(p) for p in params]
            param_group['params'] = [name for name, p in model.named_parameters() if id(p) in param_ids]
        return param_groups

    def names2params(self, wrapped_model: Module, origin2wrapped_name_map: Dict, params: List[Dict]) -> List[Dict]:
        # map the recorded parameter names (translated through `origin2wrapped_name_map`)
        # back to the parameter tensors of the wrapped model
        param_groups = deepcopy(params)
        for param_group in param_groups:
            wrapped_names = [origin2wrapped_name_map.get(name, name) for name in param_group['params']]
            param_group['params'] = [p for name, p in wrapped_model.named_parameters() if name in wrapped_names]
        return param_groups

    def call(self, wrapped_model: Module, origin2wrapped_name_map: Dict) -> Optimizer:
        args = deepcopy(self.args)
        kwargs = deepcopy(self.kwargs)
        if 'params' in kwargs:
            kwargs['params'] = self.names2params(wrapped_model, origin2wrapped_name_map, kwargs['params'])
        else:
            args[0] = self.names2params(wrapped_model, origin2wrapped_name_map, args[0])
        return self.callable_obj(*args, **kwargs)

    @staticmethod
    def from_trace(model: Module, optimizer_trace: Traceable):
        assert isinstance(optimizer_trace, Traceable), \
            'Please use nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the optimizer class before initializing the optimizer.'
        assert isinstance(optimizer_trace, Optimizer), \
            'It is not an instance of torch.optim.Optimizer.'
        return OptimizerConstructHelper(model,
                                        optimizer_trace._get_nni_attr('symbol'),
                                        *optimizer_trace._get_nni_attr('args'),
                                        **optimizer_trace._get_nni_attr('kwargs'))


class LRSchedulerConstructHelper(ConstructHelper):
    def __init__(self, lr_scheduler_class: Type[_LRScheduler], *args, **kwargs):
        args = list(args)
        # the real optimizer is not known yet, keep a placeholder and fill it in when `call` is invoked
        if 'optimizer' in kwargs:
            kwargs['optimizer'] = None
        else:
            args[0] = None
        super().__init__(lr_scheduler_class, *args, **kwargs)

    def call(self, optimizer: Optimizer) -> _LRScheduler:
        args = deepcopy(self.args)
        kwargs = deepcopy(self.kwargs)
        if 'optimizer' in kwargs:
            kwargs['optimizer'] = optimizer
        else:
            args[0] = optimizer
        return self.callable_obj(*args, **kwargs)

    @staticmethod
    def from_trace(lr_scheduler_trace: Traceable):
        assert isinstance(lr_scheduler_trace, Traceable), \
            'Please use nni.algorithms.compression.v2.pytorch.utils.trace_parameters to wrap the lr scheduler class before initializing the scheduler.'
        assert isinstance(lr_scheduler_trace, _LRScheduler), \
            'It is not an instance of torch.optim.lr_scheduler._LRScheduler.'
        return LRSchedulerConstructHelper(lr_scheduler_trace.trace_symbol,
                                          *lr_scheduler_trace.trace_args,
                                          **lr_scheduler_trace.trace_kwargs)
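
Roughly, the intended flow is: wrap the optimizer / lr scheduler class with trace_parameters, build a construct helper from the traced instance, then re-construct the real objects when they are needed. A minimal sketch of that flow (not part of this commit; the toy Linear model and the empty name map are only illustrative, a pruner would supply its own wrapped model and name map):

import torch
from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
from nni.algorithms.compression.v2.pytorch.utils.constructor_helper import (
    OptimizerConstructHelper, LRSchedulerConstructHelper)

model = torch.nn.Linear(4, 2)

# wrap the classes so that constructor arguments are recorded instead of consumed
traced_optimizer = trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.1)
traced_scheduler = trace_parameters(torch.optim.lr_scheduler.StepLR)(traced_optimizer, step_size=1)

# build helpers from the traced instances
optimizer_helper = OptimizerConstructHelper.from_trace(model, traced_optimizer)
scheduler_helper = LRSchedulerConstructHelper.from_trace(traced_scheduler)

# later, re-construct the real objects; an empty name map keeps the original parameter names
optimizer = optimizer_helper.call(model, {})
scheduler = scheduler_helper.call(optimizer)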
@@ -344,7 +344,7 @@ def load(string: Optional[str] = None, *, fp: Optional[Any] = None, ignore_comme
     return json_tricks.load(fp, obj_pairs_hooks=hooks, **json_tricks_kwargs)
-def _trace_cls(base, kw_only):
+def _trace_cls(base, kw_only, call_super=True):
     # the implementation to trace a class is to store a copy of init arguments
     # this won't support class that defines a customized new but should work for most cases
@@ -354,7 +354,7 @@ def _trace_cls(base, kw_only):
             args, kwargs = _formulate_arguments(base.__init__, args, kwargs, kw_only, is_class_init=True)
             # calling serializable object init to initialize the full object
-            super().__init__(symbol=base, args=args, kwargs=kwargs, call_super=True)
+            super().__init__(symbol=base, args=args, kwargs=kwargs, call_super=call_super)
     _copy_class_wrapper_attributes(base, wrapper)
...
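
As I read this serializer change, the new call_super flag lets trace_parameters record constructor arguments without running the wrapped class's own __init__; what from_trace then consumes is only the recorded symbol, args and kwargs. A short illustrative sketch (not from the diff):

import torch
from nni.algorithms.compression.v2.pytorch.utils import trace_parameters

traced = trace_parameters(torch.optim.SGD)(torch.nn.Linear(4, 2).parameters(), lr=0.1)
# the traced instance exposes the recorded construct information
print(traced.trace_symbol)   # expected: the original torch.optim.SGD class
print(traced.trace_kwargs)   # expected: the recorded keyword arguments, including lr=0.1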
@@ -15,7 +15,7 @@ from nni.algorithms.compression.v2.pytorch.pruning import (
     AutoCompressPruner
 )
-from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity_mask2compact
+from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity_mask2compact, trace_parameters
 class TorchModel(torch.nn.Module):
@@ -52,7 +52,7 @@ def trainer(model, optimizer, criterion):
 def get_optimizer(model):
-    return torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
+    return trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
 criterion = torch.nn.CrossEntropyLoss()
@@ -104,7 +104,7 @@ class IterativePrunerTestCase(unittest.TestCase):
         config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
         admm_params = {
             'trainer': trainer,
-            'optimizer': get_optimizer(model),
+            'traced_optimizer': get_optimizer(model),
             'criterion': criterion,
             'iterations': 10,
             'training_epochs': 1
...
@@ -18,7 +18,7 @@ from nni.algorithms.compression.v2.pytorch.pruning import (
     ADMMPruner,
     MovementPruner
 )
-from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity_mask2compact
+from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity_mask2compact, trace_parameters
 class TorchModel(torch.nn.Module):
@@ -55,7 +55,7 @@ def trainer(model, optimizer, criterion):
 def get_optimizer(model):
-    return torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
+    return trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
 criterion = torch.nn.CrossEntropyLoss()
@@ -104,7 +104,7 @@ class PrunerTestCase(unittest.TestCase):
     def test_slim_pruner(self):
         model = TorchModel()
         config_list = [{'op_types': ['BatchNorm2d'], 'total_sparsity': 0.8}]
-        pruner = SlimPruner(model=model, config_list=config_list, trainer=trainer, optimizer=get_optimizer(model),
+        pruner = SlimPruner(model=model, config_list=config_list, trainer=trainer, traced_optimizer=get_optimizer(model),
                             criterion=criterion, training_epochs=1, scale=0.001, mode='global')
         pruned_model, masks = pruner.compress()
         pruner._unwrap_model()
@@ -115,7 +115,7 @@ class PrunerTestCase(unittest.TestCase):
         model = TorchModel()
         config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
         pruner = ActivationAPoZRankPruner(model=model, config_list=config_list, trainer=trainer,
-                                          optimizer=get_optimizer(model), criterion=criterion, training_batches=1,
+                                          traced_optimizer=get_optimizer(model), criterion=criterion, training_batches=5,
                                           activation='relu', mode='dependency_aware',
                                           dummy_input=torch.rand(10, 1, 28, 28))
         pruned_model, masks = pruner.compress()
@@ -127,7 +127,7 @@ class PrunerTestCase(unittest.TestCase):
         model = TorchModel()
         config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
         pruner = ActivationMeanRankPruner(model=model, config_list=config_list, trainer=trainer,
-                                          optimizer=get_optimizer(model), criterion=criterion, training_batches=1,
+                                          traced_optimizer=get_optimizer(model), criterion=criterion, training_batches=5,
                                           activation='relu', mode='dependency_aware',
                                           dummy_input=torch.rand(10, 1, 28, 28))
         pruned_model, masks = pruner.compress()
@@ -139,7 +139,7 @@ class PrunerTestCase(unittest.TestCase):
         model = TorchModel()
         config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
         pruner = TaylorFOWeightPruner(model=model, config_list=config_list, trainer=trainer,
-                                      optimizer=get_optimizer(model), criterion=criterion, training_batches=1,
+                                      traced_optimizer=get_optimizer(model), criterion=criterion, training_batches=5,
                                       mode='dependency_aware', dummy_input=torch.rand(10, 1, 28, 28))
         pruned_model, masks = pruner.compress()
         pruner._unwrap_model()
@@ -149,7 +149,7 @@ class PrunerTestCase(unittest.TestCase):
     def test_admm_pruner(self):
         model = TorchModel()
         config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8, 'rho': 1e-3}]
-        pruner = ADMMPruner(model=model, config_list=config_list, trainer=trainer, optimizer=get_optimizer(model),
+        pruner = ADMMPruner(model=model, config_list=config_list, trainer=trainer, traced_optimizer=get_optimizer(model),
                             criterion=criterion, iterations=2, training_epochs=1)
         pruned_model, masks = pruner.compress()
         pruner._unwrap_model()
@@ -159,7 +159,7 @@ class PrunerTestCase(unittest.TestCase):
    def test_movement_pruner(self):
         model = TorchModel()
         config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
-        pruner = MovementPruner(model=model, config_list=config_list, trainer=trainer, optimizer=get_optimizer(model),
+        pruner = MovementPruner(model=model, config_list=config_list, trainer=trainer, traced_optimizer=get_optimizer(model),
                                 criterion=criterion, training_epochs=5, warm_up_step=0, cool_down_beginning_step=4)
         pruned_model, masks = pruner.compress()
         pruner._unwrap_model()
...
@@ -24,7 +24,8 @@ from nni.algorithms.compression.v2.pytorch.pruning.tools import (
     GlobalSparsityAllocator
 )
 from nni.algorithms.compression.v2.pytorch.pruning.tools.base import HookCollectorInfo
-from nni.algorithms.compression.v2.pytorch.utils import get_module_by_name
+from nni.algorithms.compression.v2.pytorch.utils import get_module_by_name, trace_parameters
+from nni.algorithms.compression.v2.pytorch.utils.constructor_helper import OptimizerConstructHelper
 class TorchModel(torch.nn.Module):
@@ -61,7 +62,7 @@ def trainer(model, optimizer, criterion):
 def get_optimizer(model):
-    return torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
+    return trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
 criterion = torch.nn.CrossEntropyLoss()
@@ -88,7 +89,8 @@ class PruningToolsTestCase(unittest.TestCase):
         model.conv1.module.weight.data = torch.ones(5, 1, 5, 5)
         model.conv2.module.weight.data = torch.ones(10, 5, 5, 5)
-        data_collector = WeightTrainerBasedDataCollector(pruner, trainer, get_optimizer(model), criterion, 1, opt_after_tasks=[opt_after])
+        optimizer_helper = OptimizerConstructHelper.from_trace(model, get_optimizer(model))
+        data_collector = WeightTrainerBasedDataCollector(pruner, trainer, optimizer_helper, criterion, 1, opt_after_tasks=[opt_after])
         data = data_collector.collect()
         assert all(torch.equal(get_module_by_name(model, module_name)[1].module.weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
         assert all(t.numel() == (t == 1).type_as(t).sum().item() for t in data.values())
@@ -102,7 +104,8 @@ class PruningToolsTestCase(unittest.TestCase):
         hook_targets = {'conv1': model.conv1.module.weight, 'conv2': model.conv2.module.weight}
         collector_info = HookCollectorInfo(hook_targets, 'tensor', _collector)
-        data_collector = SingleHookTrainerBasedDataCollector(pruner, trainer, get_optimizer(model), criterion, 2, collector_infos=[collector_info])
+        optimizer_helper = OptimizerConstructHelper.from_trace(model, get_optimizer(model))
+        data_collector = SingleHookTrainerBasedDataCollector(pruner, trainer, optimizer_helper, criterion, 2, collector_infos=[collector_info])
         data = data_collector.collect()
         assert all(len(t) == 2 for t in data.values())
...
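
For completeness, a self-contained sketch of the name-map path that the data-collector changes rely on: rebuilding the optimizer over a wrapped copy of the model whose parameters were renamed. The Wrapped module and the name map below are hypothetical stand-ins for a pruner's module wrapper, not part of this commit:

import torch
from nni.algorithms.compression.v2.pytorch.utils import trace_parameters
from nni.algorithms.compression.v2.pytorch.utils.constructor_helper import OptimizerConstructHelper

class Wrapped(torch.nn.Module):
    # hypothetical stand-in for a pruner's module wrapper, which nests the original module
    def __init__(self, module):
        super().__init__()
        self.module = module

model = torch.nn.Sequential(torch.nn.Linear(4, 2))
helper = OptimizerConstructHelper.from_trace(
    model, trace_parameters(torch.optim.SGD)(model.parameters(), lr=0.1))

# after wrapping, parameter names change, e.g. '0.weight' -> '0.module.weight'
wrapped_model = torch.nn.Sequential(Wrapped(model[0]))
name_map = {'0.weight': '0.module.weight', '0.bias': '0.module.bias'}

# the helper re-creates an SGD instance over the wrapped model's parameters
optimizer = helper.call(wrapped_model, name_map)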