Unverified Commit bbf54a88 authored by J-shang, committed by GitHub

[Compression] TransformersEvaluator (#5081)

parent 900be804
......@@ -10,3 +10,8 @@ LightningEvaluator
------------------

.. autoclass:: nni.compression.pytorch.LightningEvaluator
+
+TransformersEvaluator
+---------------------
+
+.. autoclass:: nni.compression.pytorch.TransformersEvaluator
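
An end-to-end usage example: pruning ``bert-base-cased`` on ``yelp_review_full`` with ``TaylorFOWeightPruner``, driven by a ``TransformersEvaluator``: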
import numpy as np
from datasets import load_dataset, load_metric
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments
)

import nni
from nni.compression.pytorch import TransformersEvaluator
from nni.compression.pytorch.pruning import TaylorFOWeightPruner
# load and tokenize the Yelp review dataset, then subsample it for a quick run
dataset = load_dataset('yelp_review_full')
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

def tokenize_function(examples):
    return tokenizer(examples['text'], padding='max_length', truncation=True)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
small_train_dataset = tokenized_datasets['train'].shuffle(seed=42).select(range(1000))
small_eval_dataset = tokenized_datasets['test'].shuffle(seed=42).select(range(1000))
model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased', num_labels=5)

metric = load_metric('accuracy')

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

training_args = TrainingArguments(
    output_dir='./log',
    evaluation_strategy='epoch',
    per_device_train_batch_size=32,
    num_train_epochs=3,
    max_steps=-1  # the -1 default lets num_train_epochs decide the training length
)
# wrap the Trainer class with nni.trace so the evaluator can re-create it later
trainer = nni.trace(Trainer)(
    model=model,
    args=training_args,
    train_dataset=small_train_dataset,
    eval_dataset=small_eval_dataset,
    compute_metrics=compute_metrics
)
evaluator = TransformersEvaluator(trainer)

# prune 50% of the Linear weights, collecting Taylor scores over 20 training steps
pruner = TaylorFOWeightPruner(model, [{'op_types': ['Linear'], 'sparsity': 0.5}], evaluator, 20)
_, masks = pruner.compress()
pruner.show_pruned_weights()
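
A possible follow-up, not part of the example above: the returned masks can be used to physically shrink the network with ``ModelSpeedup``. A minimal sketch, assuming the model accepts a single batch of token ids and that the pruned ``Linear`` modules are supported by the speedup engine:

import torch
from nni.compression.pytorch.speedup import ModelSpeedup

# hypothetical dummy input: one batch of token ids at BERT's maximum sequence length
dummy_input = torch.randint(0, tokenizer.vocab_size, (1, 512), dtype=torch.long)
ModelSpeedup(model, dummy_input, masks).speedup_model()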
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
-from .utils import LightningEvaluator, TorchEvaluator
+from .utils import Evaluator, LightningEvaluator, TorchEvaluator, TransformersEvaluator
......@@ -18,7 +18,7 @@ from nni.compression.pytorch.utils import count_flops_params
from .iterative_pruner import IterativePruner, PRUNER_DICT
from .tools import TaskGenerator
from .tools.rl_env import DDPG, AMCEnv
-from ..utils import LightningEvaluator, TorchEvaluator, compute_sparsity, config_list_canonical
+from ..utils import Evaluator, compute_sparsity, config_list_canonical
from ..utils.docstring import _EVALUATOR_DOCSTRING
......@@ -234,7 +234,7 @@ class AMCPruner(IterativePruner):
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
-def __init__(self, total_episode: int, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator,
+def __init__(self, total_episode: int, model: Module, config_list: List[Dict], evaluator: Evaluator,
pruning_algorithm: str = 'l1', log_dir: str = '.', keep_intermediate_result: bool = False,
ddpg_params: dict = {}, pruning_params: dict = {}, target: str = 'flops'):
...
......
......@@ -13,7 +13,7 @@ from torch.nn import Module
from .basic_pruner import ADMMPruner
from .iterative_pruner import IterativePruner, SimulatedAnnealingPruner
from .tools import LotteryTicketTaskGenerator
-from ..utils import LightningEvaluator, TorchEvaluator, OptimizerConstructHelper
+from ..utils import Evaluator, OptimizerConstructHelper
from ..utils.docstring import _EVALUATOR_DOCSTRING
_logger = logging.getLogger(__name__)
......@@ -82,7 +82,7 @@ class AutoCompressPruner(IterativePruner):
admm_params
The parameters passed to the ADMMPruner.
-  - evaluator : LightningEvaluator or TorchEvaluator.
+  - evaluator : LightningEvaluator or TorchEvaluator or TransformersEvaluator.
The same with the evaluator of AutoCompressPruner input parameter.
- iterations : int.
The total iteration number in admm pruning algorithm.
......@@ -92,7 +92,7 @@ class AutoCompressPruner(IterativePruner):
sa_params
The parameters passed to the SimulatedAnnealingPruner.
-  - evaluator : LightningEvaluator or TorchEvaluator.
+  - evaluator : LightningEvaluator or TorchEvaluator or TransformersEvaluator.
The same with the evaluator of AutoCompressPruner input parameter.
- start_temperature : float. Default: `100`.
Start temperature of the simulated annealing process.
......@@ -127,7 +127,7 @@ class AutoCompressPruner(IterativePruner):
@overload
def __init__(self, model: Module, config_list: List[Dict], total_iteration: int, admm_params: Dict,
sa_params: Dict, log_dir: str = '.', keep_intermediate_result: bool = False,
-evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False):
+evaluator: Evaluator | None = None, speedup: bool = False):
...
@overload
......
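
For illustration, a sketch of driving ``AutoCompressPruner`` with the relaxed ``Evaluator`` typing; it assumes ``model``, ``config_list``, and a ``TransformersEvaluator`` named ``evaluator`` were built as in the example above, and the placeholder values follow the parameter fragments visible in this diff:

from nni.compression.pytorch.pruning import AutoCompressPruner

# per the docstring above, the inner ADMM / simulated-annealing evaluators reuse the
# pruner's own `evaluator`, so only algorithm-specific keys are given here
admm_params = {'iterations': 10, 'training_epochs': 1}
sa_params = {'start_temperature': 100, 'stop_temperature': 20}
pruner = AutoCompressPruner(model, config_list, total_iteration=3,
                            admm_params=admm_params, sa_params=sa_params,
                            evaluator=evaluator, speedup=False)
pruner.compress()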
......@@ -53,8 +53,6 @@ from ..utils import (
OptimizerConstructHelper,
Scaling,
Evaluator,
-LightningEvaluator,
-TorchEvaluator,
ForwardHook,
TensorHook,
config_list_canonical
......@@ -151,7 +149,7 @@ _LEGACY_CRITERION = Callable[[Tensor, Tensor], Tensor]
# TODO: remove in nni v3.0.
class EvaluatorBasedPruner(BasicPruner):
-evaluator: LightningEvaluator | TorchEvaluator
+evaluator: Evaluator
using_evaluator: bool
trainer: _LEGACY_TRAINER
traced_optimizer: Optimizer
......@@ -163,7 +161,7 @@ class EvaluatorBasedPruner(BasicPruner):
# return the remaining arguments.
if (len(args) > 0 and isinstance(args[0], Evaluator)) or (len(args) == 0 and isinstance(kwargs.get('evaluator', None), Evaluator)):
init_kwargs = self._parse_args(new_api, args, kwargs, init_kwargs)
-self.evaluator: LightningEvaluator | TorchEvaluator = init_kwargs.pop('evaluator')
+self.evaluator: Evaluator = init_kwargs.pop('evaluator')
if not self.evaluator._initialization_complete:
self.evaluator._init_optimizer_helpers(model) # type: ignore
self.using_evaluator = True
......@@ -579,7 +577,7 @@ class SlimPruner(EvaluatorBasedPruner):
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
-def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator,
+def __init__(self, model: Module, config_list: List[Dict], evaluator: Evaluator,
training_epochs: int, scale: float = 0.0001, mode='global'):
...
......@@ -699,7 +697,7 @@ class ActivationPruner(EvaluatorBasedPruner):
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
-def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, training_steps: int,
+def __init__(self, model: Module, config_list: List[Dict], evaluator: Evaluator, training_steps: int,
activation: str = 'relu', mode: str = 'normal', dummy_input: Optional[Tensor] = None):
...
......@@ -970,7 +968,7 @@ class TaylorFOWeightPruner(EvaluatorBasedPruner):
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
-def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, training_steps: int,
+def __init__(self, model: Module, config_list: List[Dict], evaluator: Evaluator, training_steps: int,
mode: str = 'normal', dummy_input: Optional[Tensor] = None):
...
......@@ -1114,7 +1112,7 @@ class ADMMPruner(EvaluatorBasedPruner):
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
-def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, iterations: int,
+def __init__(self, model: Module, config_list: List[Dict], evaluator: Evaluator, iterations: int,
training_epochs: int, granularity: str = 'fine-grained'):
...
......
......@@ -15,7 +15,7 @@ from nni.algorithms.compression.v2.pytorch.base import Pruner, BasePruningSchedu
from nni.compression.pytorch.speedup import ModelSpeedup
from .tools import TaskGenerator
-from ..utils import Evaluator, LightningEvaluator, TorchEvaluator
+from ..utils import Evaluator
_logger = logging.getLogger(__name__)
......@@ -25,7 +25,7 @@ _LEGACY_EVALUATOR = Callable[[Module], float]
# TODO: remove in nni v3.0.
class EvaluatorBasedPruningScheduler(BasePruningScheduler):
-evaluator: LightningEvaluator | TorchEvaluator
+evaluator: Evaluator
using_evaluator: bool
finetuner: _LEGACY_FINETUNER
_evaluator: _LEGACY_EVALUATOR
......@@ -38,7 +38,7 @@ class EvaluatorBasedPruningScheduler(BasePruningScheduler):
if (len(args) > 0 and isinstance(args[0], Evaluator)) or \
(len(args) == 0 and isinstance(kwargs.get('evaluator', None), Evaluator)):
init_kwargs = self._parse_args(new_api, args, kwargs, new_init_kwargs)
-self.evaluator: LightningEvaluator | TorchEvaluator = init_kwargs.pop('evaluator')
+self.evaluator: Evaluator = init_kwargs.pop('evaluator')
if not self.evaluator._initialization_complete:
self.evaluator._init_optimizer_helpers(model) # type: ignore
self.using_evaluator = True
......@@ -96,7 +96,7 @@ class PruningScheduler(EvaluatorBasedPruningScheduler):
"""
@overload
-def __init__(self, pruner: Pruner, task_generator: TaskGenerator, evaluator: LightningEvaluator | TorchEvaluator,
+def __init__(self, pruner: Pruner, task_generator: TaskGenerator, evaluator: Evaluator,
speedup: bool = False, reset_weight: bool = False):
...
......
......@@ -30,8 +30,7 @@ from .tools import (
)
from ..utils import (
OptimizerConstructHelper,
-LightningEvaluator,
-TorchEvaluator
+Evaluator
)
from ..utils.docstring import _EVALUATOR_DOCSTRING
......@@ -115,7 +114,7 @@ class LinearPruner(IterativePruner):
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
-evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False,
+evaluator: Evaluator | None = None, speedup: bool = False,
pruning_params: Dict = {}):
...
......@@ -197,7 +196,7 @@ class AGPPruner(IterativePruner):
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
-evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False,
+evaluator: Evaluator | None = None, speedup: bool = False,
pruning_params: Dict = {}):
...
......@@ -292,7 +291,7 @@ class LotteryTicketPruner(IterativePruner):
@overload
def __init__(self, model: Module, config_list: List[Dict], pruning_algorithm: str,
total_iteration: int, log_dir: str = '.', keep_intermediate_result: bool = False,
-evaluator: LightningEvaluator | TorchEvaluator | None = None, speedup: bool = False,
+evaluator: Evaluator | None = None, speedup: bool = False,
reset_weight: bool = True, pruning_params: Dict = {}):
...
......@@ -386,7 +385,7 @@ class SimulatedAnnealingPruner(IterativePruner):
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
-def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator,
+def __init__(self, model: Module, config_list: List[Dict], evaluator: Evaluator,
start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
perturbation_magnitude: float = 0.35, pruning_algorithm: str = 'level', pruning_params: Dict = {},
log_dir: Union[str, Path] = '.', keep_intermediate_result: bool = False, speedup: bool = False):
......
......@@ -27,8 +27,7 @@ from .tools import (
)
from ..utils import (
-LightningEvaluator,
-TorchEvaluator,
+Evaluator,
Scaling
)
......@@ -188,7 +187,7 @@ class MovementPruner(EvaluatorBasedPruner):
""".format(evaluator_docstring=_EVALUATOR_DOCSTRING)
@overload
-def __init__(self, model: Module, config_list: List[Dict], evaluator: LightningEvaluator | TorchEvaluator, warm_up_step: int,
+def __init__(self, model: Module, config_list: List[Dict], evaluator: Evaluator, warm_up_step: int,
cool_down_beginning_step: int, training_epochs: int | None = None, training_steps: int | None = None,
regular_scale: float | None = None, movement_mode: Literal['hard', 'soft'] = 'hard',
sparse_granularity: Literal['auto', 'finegrained'] = 'finegrained'):
......
......@@ -14,6 +14,7 @@ from .evaluator import (
Evaluator,
LightningEvaluator,
TorchEvaluator,
+TransformersEvaluator,
Hook,
BackwardHook,
ForwardHook,
......
......@@ -18,10 +18,18 @@ try:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
except ImportError:
-LightingInstalled = False
+LIGHTNING_INSTALLED = False
else:
-LightingInstalled = True
+LIGHTNING_INSTALLED = True
+try:
+    from transformers.trainer import Trainer as HFTrainer
+except ImportError:
+    TRANSFORMERS_INSTALLED = False
+else:
+    TRANSFORMERS_INSTALLED = True
import nni
from nni.common import is_traceable
from .constructor_helper import OptimizerConstructHelper, LRSchedulerConstructHelper
......@@ -297,7 +305,7 @@ class LightningEvaluator(Evaluator):
def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
dummy_input: Any | None = None):
-assert LightingInstalled, 'pytorch_lightning is not installed.'
+assert LIGHTNING_INSTALLED, 'pytorch_lightning is not installed.'
err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
......@@ -766,3 +774,175 @@ class TorchEvaluator(Evaluator):
def get_dummy_input(self) -> Any:
return self.dummy_input
class TransformersEvaluator(Evaluator):
    """
    TransformersEvaluator is for users who work with the HuggingFace ``transformers.trainer.Trainer``.

    Here is an example of using ``transformers.trainer.Trainer`` to initialize an evaluator:

    .. code-block:: python

        from transformers.trainer import Trainer

        # wrap the Trainer class with nni.trace
        trainer = nni.trace(Trainer)(model=model)
        evaluator = TransformersEvaluator(trainer)

        # if you want to use a customized optimizer & lr_scheduler,
        # please also wrap the Optimizer & _LRScheduler classes
        optimizer = nni.trace(Adam)(...)
        lr_scheduler = nni.trace(LambdaLR)(...)
        trainer = nni.trace(Trainer)(model=model, ..., optimizers=(optimizer, lr_scheduler))
        evaluator = TransformersEvaluator(trainer)

    Parameters
    ----------
    trainer
        An ``nni.trace(transformers.trainer.Trainer)`` instance. The trainer will be re-initialized
        inside the evaluator, so wrapping with ``nni.trace`` is required to get the initialization arguments.
    dummy_input
        Optional. The dummy_input is used to trace the graph; it is the same as ``example_inputs`` in
        `torch.jit.trace <https://pytorch.org/docs/stable/generated/torch.jit.trace.html?highlight=torch%20jit%20trace#torch.jit.trace>`_.
    """
    def __init__(self, trainer: HFTrainer, dummy_input: Any | None = None) -> None:
        assert TRANSFORMERS_INSTALLED, 'transformers is not installed.'
        assert is_traceable(trainer), 'Only support traced Trainer, please use nni.trace(Trainer) to initialize the trainer.'
        self.traced_trainer = trainer
        self.dummy_input = dummy_input
        self.model: Module | None = None
        # keep a handle on the original classmethod so it can be restored after patching
        self._ori_trainer_attr = {
            'get_optimizer_cls_and_kwargs': HFTrainer.get_optimizer_cls_and_kwargs
        }
        self._initialization_complete = False
    def _init_optimizer_helpers(self, pure_model: Module | pl.LightningModule):
        assert self._initialization_complete is False, 'Evaluator initialization is already complete.'

        if self.traced_trainer.optimizer is not None and is_traceable(self.traced_trainer.optimizer):
            self._optimizer_helper = OptimizerConstructHelper.from_trace(pure_model, self.traced_trainer.optimizer)
        else:
            warn_msg = 'trainer.optimizer is not wrapped by nni.trace, or trainer.optimizer is None, ' + \
                       'the HuggingFace default optimizer will be used.'
            _logger.warning(warn_msg)
            self.traced_trainer.optimizer = None

            # temporarily patch the classmethod so that create_optimizer() builds a traced
            # optimizer, whose construction arguments can be recorded by the helper
            def patched_get_optimizer_cls_and_kwargs(args) -> Tuple[Any, Any]:
                optimizer_cls, optimizer_kwargs = self._ori_trainer_attr['get_optimizer_cls_and_kwargs'](args)
                return nni.trace(optimizer_cls), optimizer_kwargs

            HFTrainer.get_optimizer_cls_and_kwargs = patched_get_optimizer_cls_and_kwargs
            self._optimizer_helper = OptimizerConstructHelper.from_trace(pure_model, self.traced_trainer.create_optimizer())
            HFTrainer.get_optimizer_cls_and_kwargs = self._ori_trainer_attr['get_optimizer_cls_and_kwargs']
            self.traced_trainer.optimizer = None

        if self.traced_trainer.lr_scheduler is not None and is_traceable(self.traced_trainer.lr_scheduler):
            self._lr_scheduler_helper = LRSchedulerConstructHelper.from_trace(self.traced_trainer.lr_scheduler)
        else:
            warn_msg = 'trainer.lr_scheduler is not wrapped by nni.trace, or trainer.lr_scheduler is None, ' + \
                       'the HuggingFace default lr_scheduler will be used.'
            _logger.warning(warn_msg)
            self.traced_trainer.lr_scheduler = None
            self._lr_scheduler_helper = None

        self._initialization_complete = True
    def bind_model(self, model: Module | pl.LightningModule, param_names_map: Dict[str, str] | None = None):
        err_msg = 'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before binding a model.'
        assert self._initialization_complete is True, err_msg
        assert isinstance(model, Module)
        if self.model is not None:
            _logger.warning('A model is already bound, it will be unbound before binding the new model.')
            self.unbind_model()

        self.model = model

        # re-initialize the Trainer from the recorded trace arguments, substituting the bound model
        args = list(self.traced_trainer.trace_args)
        kwargs = dict()
        kwargs.update(self.traced_trainer.trace_kwargs)
        if len(args) != 0:
            assert isinstance(args[0], Module) or args[0] is None
            args[0] = self.model
        else:
            kwargs['model'] = self.model
        self.trainer: HFTrainer = self.traced_trainer.trace_symbol(*args, **kwargs)
        self._ori_trainer_attr['compute_loss'] = self.trainer.compute_loss

        self._param_names_map = param_names_map
        self.trainer.optimizer = self._optimizer_helper.call(self.model, self._param_names_map)
        self._ori_trainer_attr['optimizer.step'] = self.trainer.optimizer.step
    def unbind_model(self):
        if self.model:
            self.revert_loss()
            self.revert_optimizer_step()
            self.remove_all_hooks()

            self._ori_trainer_attr.pop('optimizer.step', None)
            self.trainer.optimizer = None
            self._param_names_map = None
            self._ori_trainer_attr.pop('compute_loss', None)
            self.trainer = None
            self.model = None
        else:
            _logger.warning('No model is bound, nothing to unbind.')
    def patch_loss(self, patch: Callable[[Tensor], Tensor]):
        old_compute_loss = self.trainer.compute_loss

        def patched_compute_loss(_, model: Any, inputs: Any, return_outputs: bool = False):
            result = old_compute_loss(model, inputs, return_outputs)
            if return_outputs:
                return patch(result[0]), result[1]
            else:
                return patch(result)

        self.trainer.compute_loss = types.MethodType(patched_compute_loss, self.trainer)
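
    # Note: ``patch`` is a Tensor -> Tensor callable applied to the loss computed by
    # ``Trainer.compute_loss``. A hypothetical illustration (not part of this commit):
    # evaluator.patch_loss(lambda loss: loss + 1e-4 * regularizer(model)), where
    # ``regularizer`` is any extra penalty; ``revert_loss`` below restores the original.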
    def revert_loss(self):
        self.trainer.compute_loss = self._ori_trainer_attr['compute_loss']
    def patch_optimizer_step(self, before_step_tasks: List[Callable], after_step_tasks: List[Callable]):
        assert self.trainer.optimizer is not None
        old_step = self.trainer.optimizer.step

        def patched_step(_, *args, **kwargs):
            for task in before_step_tasks:
                task()
            # call the original optimizer step method
            output = old_step(*args, **kwargs)
            for task in after_step_tasks:
                task()
            return output

        self.trainer.optimizer.step = types.MethodType(patched_step, self.trainer.optimizer)
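
    # Note: the step tasks are zero-argument callables. Pruners typically use them to
    # gather statistics just before the parameter update and to re-apply masks just
    # after it; a hypothetical illustration (names not from this commit):
    # evaluator.patch_optimizer_step([collect_scores], [reapply_masks])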
    def revert_optimizer_step(self):
        assert self.trainer.optimizer is not None
        self.trainer.optimizer.step = self._ori_trainer_attr['optimizer.step']
    def train(self, max_steps: int | None = None, max_epochs: int | None = None):
        assert self.model is not None

        # temporarily override the trainer's stop conditions, then restore them afterwards
        ori_steps, ori_epochs = self.trainer.args.max_steps, self.trainer.args.num_train_epochs
        if max_epochs is not None:
            self.trainer.args.num_train_epochs = max_epochs
        if max_steps is not None:
            self.trainer.args.max_steps = max_steps

        self.trainer.lr_scheduler = self._lr_scheduler_helper.call(self.trainer.optimizer) if self._lr_scheduler_helper else None
        self.trainer.train()
        self.trainer.lr_scheduler = None

        self.trainer.args.max_steps, self.trainer.args.num_train_epochs = ori_steps, ori_epochs
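
    # HF semantics: when ``args.max_steps`` > 0 it takes precedence over
    # ``num_train_epochs``, so passing only ``max_epochs`` assumes ``max_steps``
    # keeps its -1 default.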
    def finetune(self):
        self.train()

    def evaluate(self) -> float | None | Tuple[float, Dict[str, Any]] | Tuple[None, Dict[str, Any]]:
        return self.trainer.evaluate()

    def get_dummy_input(self) -> Any:
        return self.dummy_input
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
-from nni.algorithms.compression.v2.pytorch import TorchEvaluator, LightningEvaluator
+from nni.algorithms.compression.v2.pytorch import TorchEvaluator, LightningEvaluator, TransformersEvaluator
from .speedup import ModelSpeedup
from .compressor import Compressor, Pruner, Quantizer
from .utils.apply_compression import apply_compression_results