Unverified commit bbf54a88 authored by J-shang, committed by GitHub

[Compression] TransformersEvaluator (#5081)
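Add a ``TransformersEvaluator`` built on top of Hugging Face ``transformers.trainer.Trainer``, and relax the evaluator type hints across the pruners and schedulers from ``LightningEvaluator | TorchEvaluator`` to the common ``Evaluator`` base class.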

parent 900be804
@@ -10,3 +10,8 @@ LightningEvaluator
------------------
.. autoclass:: nni.compression.pytorch.LightningEvaluator

TransformersEvaluator
---------------------
.. autoclass:: nni.compression.pytorch.TransformersEvaluator
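
The usage example shipped with this change: a ``bert-base-cased`` classifier on ``yelp_review_full``, where the Hugging Face ``Trainer`` is wrapped in ``nni.trace``, handed to ``TransformersEvaluator``, and used to drive ``TaylorFOWeightPruner``.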
import numpy as np
from datasets import load_dataset, load_metric
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments
)

import nni
from nni.compression.pytorch import TransformersEvaluator
from nni.compression.pytorch.pruning import TaylorFOWeightPruner

# Prepare small train/eval splits of the Yelp review dataset.
dataset = load_dataset('yelp_review_full')
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
small_train_dataset = tokenized_datasets['train'].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets['test'].shuffle(seed=42).select(range(1000))

model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased', num_labels=5)

metric = load_metric('accuracy')

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

training_args = TrainingArguments(
    output_dir='./log',
    evaluation_strategy='epoch',
    per_device_train_batch_size=32,
    num_train_epochs=3,
    max_steps=-1
)

# Wrap the Trainer class with nni.trace so the evaluator can re-create it later.
trainer = nni.trace(Trainer)(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics
)
evaluator = TransformersEvaluator(trainer)

# Prune all Linear layers to 50% sparsity; the final positional argument is
# training_steps, the number of steps used to collect Taylor importance scores.
pruner = TaylorFOWeightPruner(model, [{'op_types': ['Linear'], 'sparsity': 0.5}], evaluator, 20)
_, masks = pruner.compress()
pruner.show_pruned_weights()
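
A natural follow-up, not part of the original example, is applying the generated masks with ``ModelSpeedup``. The sketch below assumes the NNI v2.x speedup API; ``dummy_input`` is a hypothetical placeholder, and a real run needs tensors matching the model's forward signature.

# Sketch only: propagate the masks to physically shrink the model.
import torch
from nni.compression.pytorch import ModelSpeedup

pruner._unwrap_model()  # remove the pruner's wrappers before tracing
dummy_input = torch.randint(0, tokenizer.vocab_size, (1, 128))  # hypothetical input
ModelSpeedup(model, dummy_input, masks).speedup_model()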
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .utils import Evaluator, LightningEvaluator, TorchEvaluator, TransformersEvaluator
@@ -18,7 +18,7 @@ from nni.compression.pytorch.utils import count_flops_params
from .iterative_pruner import IterativePruner, PRUNER_DICT
from .tools import TaskGenerator
from .tools.rl_env import DDPG, AMCEnv
from ..utils import Evaluator, compute_sparsity, config_list_canonical
from ..utils.docstring import _EVALUATOR_DOCSTRING

@@ -234,7 +234,7 @@ class AMCPruner(IterativePruner):
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    @overload
    def __init__(self, total_episode: int, model: Module, config_list: List[Dict], evaluator: Evaluator,
                 pruning_algorithm: str = 'l1', log_dir: str = '.', keep_intermediate_result: bool = False,
                 ddpg_params: dict = {}, pruning_params: dict = {}, target: str = 'flops'):
        ...
@@ -13,7 +13,7 @@ from torch.nn import Module
from .basic_pruner import ADMMPruner
from .iterative_pruner import IterativePruner, SimulatedAnnealingPruner
from .tools import LotteryTicketTaskGenerator
from ..utils import Evaluator, OptimizerConstructHelper
from ..utils.docstring import _EVALUATOR_DOCSTRING

_logger = logging.getLogger(__name__)
@@ -82,7 +82,7 @@ class AutoCompressPruner(IterativePruner):
    admm_params
        The parameters passed to the ADMMPruner.

        - evaluator : LightningEvaluator or TorchEvaluator or TransformersEvaluator.
            The same as the evaluator input parameter of AutoCompressPruner.
        - iterations : int.
            The total number of iterations in the ADMM pruning algorithm.

@@ -92,7 +92,7 @@ class AutoCompressPruner(IterativePruner):
    sa_params
        The parameters passed to the SimulatedAnnealingPruner.

        - evaluator : LightningEvaluator or TorchEvaluator or TransformersEvaluator.
            The same as the evaluator input parameter of AutoCompressPruner.
        - start_temperature : float. Default: `100`.
            Start temperature of the simulated annealing process.

@@ -127,7 +127,7 @@ class AutoCompressPruner(IterativePruner):
    @overload
    def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
                 sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
                 evaluator: Evaluator | None = None, speedup: bool = False):
        ...

    @overload
@@ -53,8 +53,6 @@ from ..utils import (
    OptimizerConstructHelper,
    Scaling,
    Evaluator,
    ForwardHook,
    TensorHook,
    config_list_canonical

@@ -151,7 +149,7 @@ _LEGACY_CRITERION = Callable[[Tensor, Tensor], Tensor]

# TODO: remove in nni v3.0.
class EvaluatorBasedPruner(BasicPruner):
    evaluator: Evaluator
    using_evaluator: bool
    trainer: _LEGACY_TRAINER
    traced_optimizer: Optimizer

@@ -163,7 +161,7 @@ class EvaluatorBasedPruner(BasicPruner):
        # return the remaining arguments.
        if (len(args) > 0 and isinstance(args[0], Evaluator)) or (len(args) == 0 and isinstance(kwargs.get('evaluator', None), Evaluator)):
            init_kwargs = self._parse_args(new_api, args, kwargs, init_kwargs)
            self.evaluator: Evaluator = init_kwargs.pop('evaluator')
            if not self.evaluator._initialization_complete:
                self.evaluator._init_optimizer_helpers(model)  # type: ignore
            self.using_evaluator = True
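
The ``isinstance(..., Evaluator)`` check above is what routes a call to the new evaluator-based API instead of the legacy trainer/optimizer keyword API backed by the ``_LEGACY_*`` attributes (slated for removal in nni v3.0). A sketch of the two call styles, with the legacy keyword names assumed from the v2 pruner signatures rather than confirmed by this diff:

# New-style call: the argument after config_list is an Evaluator instance.
pruner = TaylorFOWeightPruner(model, config_list, evaluator, training_steps=20)

# Legacy-style call (assumed keyword names -- trainer, traced_optimizer, criterion;
# kept only for backward compatibility until nni v3.0):
# pruner = TaylorFOWeightPruner(model, config_list, trainer=train_fn,
#                               traced_optimizer=traced_optimizer, criterion=criterion,
#                               training_batches=20)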
@@ -579,7 +577,7 @@ class SlimPruner(EvaluatorBasedPruner):
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    @overload
    def __init__(self, model: Module, config_list: List[Dict], evaluator: Evaluator,
                 training_epochs: int, scale: float = 0.0001, mode='global'):
        ...

@@ -699,7 +697,7 @@ class ActivationPruner(EvaluatorBasedPruner):
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    @overload
    def __init__(self, model: Module, config_list: List[Dict], evaluator: Evaluator, training_steps: int,
                 activation: str = 'relu', mode: str = 'normal', dummy_input: Optional[Tensor] = None):
        ...

@@ -970,7 +968,7 @@ class TaylorFOWeightPruner(EvaluatorBasedPruner):
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    @overload
    def __init__(self, model: Module, config_list: List[Dict], evaluator: Evaluator, training_steps: int,
                 mode: str = 'normal', dummy_input: Optional[Tensor] = None):
        ...

@@ -1114,7 +1112,7 @@ class ADMMPruner(EvaluatorBasedPruner):
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    @overload
    def __init__(self, model: Module, config_list: List[Dict], evaluator: Evaluator, iterations: int,
                 training_epochs: int, granularity: str = 'fine-grained'):
        ...
@@ -15,7 +15,7 @@ from nni.algorithms.compression.v2.pytorch.base import Pruner, BasePruningScheduler
from nni.compression.pytorch.speedup import ModelSpeedup
from .tools import TaskGenerator
from ..utils import Evaluator

_logger = logging.getLogger(__name__)
@@ -25,7 +25,7 @@ _LEGACY_EVALUATOR = Callable[[Module], float]

# TODO: remove in nni v3.0.
class EvaluatorBasedPruningScheduler(BasePruningScheduler):
    evaluator: Evaluator
    using_evaluator: bool
    finetuner: _LEGACY_FINETUNER
    _evaluator: _LEGACY_EVALUATOR

@@ -38,7 +38,7 @@ class EvaluatorBasedPruningScheduler(BasePruningScheduler):
        if (len(args) > 0 and isinstance(args[0], Evaluator)) or \
                (len(args) == 0 and isinstance(kwargs.get('evaluator', None), Evaluator)):
            init_kwargs = self._parse_args(new_api, args, kwargs, new_init_kwargs)
            self.evaluator: Evaluator = init_kwargs.pop('evaluator')
            if not self.evaluator._initialization_complete:
                self.evaluator._init_optimizer_helpers(model)  # type: ignore
            self.using_evaluator = True

@@ -96,7 +96,7 @@ class PruningScheduler(EvaluatorBasedPruningScheduler):
    """

    @overload
    def __init__(self, pruner: Pruner, task_generator: TaskGenerator, evaluator: Evaluator,
                 speedup: bool = False, reset_weight: bool = False):
        ...
@@ -30,8 +30,7 @@ from .tools import (
)
from ..utils import (
    OptimizerConstructHelper,
    Evaluator
)
from ..utils.docstring import _EVALUATOR_DOCSTRING

@@ -115,7 +114,7 @@ class LinearPruner(IterativePruner):
    @overload
    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 evaluator: Evaluator | None = None, speedup: bool = False,
                 pruning_params: Dict = {}):
        ...

@@ -197,7 +196,7 @@ class AGPPruner(IterativePruner):
    @overload
    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 evaluator: Evaluator | None = None, speedup: bool = False,
                 pruning_params: Dict = {}):
        ...

@@ -292,7 +291,7 @@ class LotteryTicketPruner(IterativePruner):
    @overload
    def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
                 total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
                 evaluator: Evaluator | None = None, speedup: bool = False,
                 reset_weight: bool = True, pruning_params: Dict = {}):
        ...

@@ -386,7 +385,7 @@ class SimulatedAnnealingPruner(IterativePruner):
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    @overload
    def __init__(self, model: Module, config_list: List[Dict], evaluator: Evaluator,
                 start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
                 perturbation_magnitude: float = 0.35, pruning_algorithm: str = 'level', pruning_params: Dict = {},
                 log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False, speedup: bool = False):
@@ -27,8 +27,7 @@ from .tools import (
)
from ..utils import (
    Evaluator,
    Scaling
)

@@ -188,7 +187,7 @@ class MovementPruner(EvaluatorBasedPruner):
    """.format(evaluator_docstring=_EVALUATOR_DOCSTRING)

    @overload
    def __init__(self, model: Module, config_list: List[Dict], evaluator: Evaluator, warm_up_step: int,
                 cool_down_beginning_step: int, training_epochs: int | None = None, training_steps: int | None = None,
                 regular_scale: float | None = None, movement_mode: Literal['hard', 'soft'] = 'hard',
                 sparse_granularity: Literal['auto', 'finegrained'] = 'finegrained'):
@@ -14,6 +14,7 @@ from .evaluator import (
    Evaluator,
    LightningEvaluator,
    TorchEvaluator,
    TransformersEvaluator,
    Hook,
    BackwardHook,
    ForwardHook,
@@ -18,10 +18,18 @@ try:
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import Callback
except ImportError:
    LIGHTNING_INSTALLED = False
else:
    LIGHTNING_INSTALLED = True

try:
    from transformers.trainer import Trainer as HFTrainer
except ImportError:
    TRANSFORMERS_INSTALLED = False
else:
    TRANSFORMERS_INSTALLED = True

import nni
from nni.common import is_traceable
from .constructor_helper import OptimizerConstructHelper, LRSchedulerConstructHelper
@@ -297,7 +305,7 @@ class LightningEvaluator(Evaluator):
    def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
                 dummy_input: Any | None = None):
        assert LIGHTNING_INSTALLED, 'pytorch_lightning is not installed.'
        err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
        err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
        assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg

@@ -766,3 +774,175 @@ class TorchEvaluator(Evaluator):
    def get_dummy_input(self) -> Any:
        return self.dummy_input

class TransformersEvaluator(Evaluator):
    """
    TransformersEvaluator is for users who fine-tune with Hugging Face ``transformers.trainer.Trainer``.

    Here is an example of using ``transformers.trainer.Trainer`` to initialize an evaluator:

    .. code-block:: python

        from transformers.trainer import Trainer

        # wrap the Trainer class with nni.trace
        trainer = nni.trace(Trainer)(model=model)
        evaluator = TransformersEvaluator(trainer)

        # if you want to use a customized optimizer & lr_scheduler,
        # please also wrap the Optimizer & _LRScheduler classes
        optimizer = nni.trace(Adam)(...)
        lr_scheduler = nni.trace(LambdaLR)(...)
        trainer = nni.trace(Trainer)(model=model, ..., optimizers=(optimizer, lr_scheduler))
        evaluator = TransformersEvaluator(trainer)

    Parameters
    ----------
    trainer
        An ``nni.trace(transformers.trainer.Trainer)`` instance. The trainer will be re-initialized
        inside the evaluator, so wrapping with ``nni.trace`` is required to capture the
        initialization arguments.
    dummy_input
        Optional. The dummy_input is used to trace the graph; it is the same as ``example_inputs`` in
        `torch.jit.trace <https://pytorch.org/docs/stable/generated/torch.jit.trace.html?highlight=torch%20jit%20trace#torch.jit.trace>`_.
    """

    def __init__(self, trainer: HFTrainer, dummy_input: Any | None = None) -> None:
        assert TRANSFORMERS_INSTALLED, 'transformers is not installed.'
        assert is_traceable(trainer), 'Only support traced Trainer, please use nni.trace(Trainer) to initialize the trainer.'
        self.traced_trainer = trainer
        self.dummy_input = dummy_input
        self.model: Module | None = None
        # keep the original Trainer attributes that get patched later,
        # so they can be restored on unbind/revert
        self._ori_trainer_attr = {
            'get_optimizer_cls_and_kwargs': HFTrainer.get_optimizer_cls_and_kwargs
        }
        self._initialization_complete = False

    def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule):
        assert self._initialization_complete is False, 'Evaluator initialization is already complete.'

        if self.traced_trainer.optimizer is not None and is_traceable(self.traced_trainer.optimizer):
            self._optimizer_helper = OptimizerConstructHelper.from_trace(pure_model, self.traced_trainer.optimizer)
        else:
            warn_msg = 'trainer.optimizer is not wrapped by nni.trace, or trainer.optimizer is None, ' + \
                       'will use the huggingface default optimizer.'
            _logger.warning(warn_msg)
            self.traced_trainer.optimizer = None
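
            # Temporarily patch the class-level staticmethod `get_optimizer_cls_and_kwargs`
            # so that the optimizer class chosen by `Trainer.create_optimizer` is wrapped
            # with `nni.trace` and its construction arguments can be recorded; the original
            # staticmethod is restored immediately afterwards.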
            def patched_get_optimizer_cls_and_kwargs(args) -> Tuple[Any, Any]:
                optimizer_cls, optimizer_kwargs = self._ori_trainer_attr['get_optimizer_cls_and_kwargs'](args)
                return nni.trace(optimizer_cls), optimizer_kwargs

            HFTrainer.get_optimizer_cls_and_kwargs = patched_get_optimizer_cls_and_kwargs
            self._optimizer_helper = OptimizerConstructHelper.from_trace(pure_model, self.traced_trainer.create_optimizer())
            HFTrainer.get_optimizer_cls_and_kwargs = self._ori_trainer_attr['get_optimizer_cls_and_kwargs']
            self.traced_trainer.optimizer = None

        if self.traced_trainer.lr_scheduler is not None and is_traceable(self.traced_trainer.lr_scheduler):
            self._lr_scheduler_helper = LRSchedulerConstructHelper.from_trace(self.traced_trainer.lr_scheduler)
        else:
            warn_msg = 'trainer.lr_scheduler is not wrapped by nni.trace, or trainer.lr_scheduler is None, ' + \
                       'will use the huggingface default lr_scheduler.'
            _logger.warning(warn_msg)
            self.traced_trainer.lr_scheduler = None
            self._lr_scheduler_helper = None

        self._initialization_complete = True

    def bind_model(self, model: Module | pl.LightningModule, param_names_map: Dict[str, str] | None = None):
        err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before binding a model.'
        assert self._initialization_complete is True, err_msg
        assert isinstance(model, Module)
        if self.model is not None:
            _logger.warning('Already bound a model, will unbind it before binding a new model.')
            self.unbind_model()

        self.model = model

        # re-initialize the Trainer from the recorded trace arguments,
        # substituting the bound model for the original one
        args = list(self.traced_trainer.trace_args)
        kwargs = dict()
        kwargs.update(self.traced_trainer.trace_kwargs)
        if len(args) != 0:
            assert isinstance(args[0], Module) or args[0] is None
            args[0] = self.model
        else:
            kwargs['model'] = self.model
        self.trainer: HFTrainer = self.traced_trainer.trace_symbol(*args, **kwargs)
        self._ori_trainer_attr['compute_loss'] = self.trainer.compute_loss

        self._param_names_map = param_names_map
        self.trainer.optimizer = self._optimizer_helper.call(self.model, self._param_names_map)
        self._ori_trainer_attr['optimizer.step'] = self.trainer.optimizer.step

    def unbind_model(self):
        if self.model:
            self.revert_loss()
            self.revert_optimizer_step()
            self.remove_all_hooks()
            self._ori_trainer_attr.pop('optimizer.step', None)
            self.trainer.optimizer = None
            self._param_names_map = None
            self._ori_trainer_attr.pop('compute_loss', None)
            self.trainer = None
            self.model = None
        else:
            _logger.warning('Did not bind any model, no need to unbind model.')

    def patch_loss(self, patch: Callable[[Tensor], Tensor]):
        old_compute_loss = self.trainer.compute_loss

        def patched_compute_loss(_, model: Any, inputs: Any, return_outputs: bool = False):
            result = old_compute_loss(model, inputs, return_outputs)
            if return_outputs:
                return patch(result[0]), result[1]
            else:
                return patch(result)

        # bind the patched function as a method of the Trainer instance
        self.trainer.compute_loss = types.MethodType(patched_compute_loss, self.trainer)

    def revert_loss(self):
        self.trainer.compute_loss = self._ori_trainer_attr['compute_loss']

    def patch_optimizer_step(self, before_step_tasks: List[Callable], after_step_tasks: List[Callable]):
        assert self.trainer.optimizer is not None
        old_step = self.trainer.optimizer.step

        def patched_step(_, *args, **kwargs):
            for task in before_step_tasks:
                task()
            # call the original optimizer step method
            output = old_step(*args, **kwargs)
            for task in after_step_tasks:
                task()
            return output

        self.trainer.optimizer.step = types.MethodType(patched_step, self.trainer.optimizer)

    def revert_optimizer_step(self):
        assert self.trainer.optimizer is not None
        self.trainer.optimizer.step = self._ori_trainer_attr['optimizer.step']

    def train(self, max_steps: int | None = None, max_epochs: int | None = None):
        assert self.model is not None

        # temporarily override the step/epoch limits, then restore them after training
        ori_steps, ori_epochs = self.trainer.args.max_steps, self.trainer.args.num_train_epochs
        if max_epochs is not None:
            self.trainer.args.num_train_epochs = max_epochs
        if max_steps is not None:
            self.trainer.args.max_steps = max_steps

        self.trainer.lr_scheduler = self._lr_scheduler_helper.call(self.trainer.optimizer) if self._lr_scheduler_helper else None
        self.trainer.train()
        self.trainer.lr_scheduler = None

        self.trainer.args.max_steps, self.trainer.args.num_train_epochs = ori_steps, ori_epochs

    def finetune(self):
        self.train()

    def evaluate(self) -> float | None | Tuple[float, Dict[str, Any]] | Tuple[None, Dict[str, Any]]:
        return self.trainer.evaluate()

    def get_dummy_input(self) -> Any:
        return self.dummy_input
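
For orientation, here is a hedged sketch of how a pruner typically drives this evaluator's lifecycle; the hook callables (`collect_taylor_scores`, `apply_masks`) are hypothetical stand-ins, not names from this commit:

# Sketch under assumptions: a pruner-side view of the Evaluator protocol.
evaluator._init_optimizer_helpers(model)   # record how to rebuild the optimizer
evaluator.bind_model(model)                # re-creates the Trainer around `model`

def collect_taylor_scores():               # hypothetical before-step hook
    ...

def apply_masks():                         # hypothetical after-step hook
    ...

evaluator.patch_optimizer_step([collect_taylor_scores], [apply_masks])
evaluator.train(max_steps=20)              # run a bounded data-collection pass
metrics = evaluator.evaluate()             # delegates to Trainer.evaluate()
evaluator.unbind_model()                   # restores all patched attributes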
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from nni.algorithms.compression.v2.pytorch import TorchEvaluator, LightningEvaluator, TransformersEvaluator
from .speedup import ModelSpeedup
from .compressor import Compressor, Pruner, Quantizer
from .utils.apply_compression import apply_compression_results