Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bbf54a88
Unverified
Commit
bbf54a88
authored
Aug 26, 2022
by
J-shang
Committed by
GitHub
Aug 26, 2022
Browse files
[Comporession] TransformersEvaluator (#5081)
parent
900be804
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
271 additions
and
32 deletions
+271
-32
docs/source/reference/compression/evaluator.rst
docs/source/reference/compression/evaluator.rst
+5
-0
examples/model_compress/pruning/taylorfo_transformers_evaluator.py
...model_compress/pruning/taylorfo_transformers_evaluator.py
+57
-0
nni/algorithms/compression/v2/pytorch/__init__.py
nni/algorithms/compression/v2/pytorch/__init__.py
+1
-1
nni/algorithms/compression/v2/pytorch/pruning/amc_pruner.py
nni/algorithms/compression/v2/pytorch/pruning/amc_pruner.py
+2
-2
nni/algorithms/compression/v2/pytorch/pruning/auto_compress_pruner.py
...ms/compression/v2/pytorch/pruning/auto_compress_pruner.py
+4
-4
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py
...algorithms/compression/v2/pytorch/pruning/basic_pruner.py
+6
-8
nni/algorithms/compression/v2/pytorch/pruning/basic_scheduler.py
...orithms/compression/v2/pytorch/pruning/basic_scheduler.py
+4
-4
nni/algorithms/compression/v2/pytorch/pruning/iterative_pruner.py
...rithms/compression/v2/pytorch/pruning/iterative_pruner.py
+5
-6
nni/algorithms/compression/v2/pytorch/pruning/movement_pruner.py
...orithms/compression/v2/pytorch/pruning/movement_pruner.py
+2
-3
nni/algorithms/compression/v2/pytorch/utils/__init__.py
nni/algorithms/compression/v2/pytorch/utils/__init__.py
+1
-0
nni/algorithms/compression/v2/pytorch/utils/evaluator.py
nni/algorithms/compression/v2/pytorch/utils/evaluator.py
+183
-3
nni/compression/pytorch/__init__.py
nni/compression/pytorch/__init__.py
+1
-1
No files found.
docs/source/reference/compression/evaluator.rst
View file @
bbf54a88
...
@@ -10,3 +10,8 @@ LightningEvaluator
...
@@ -10,3 +10,8 @@ LightningEvaluator
------------------
------------------
.. autoclass:: nni.compression.pytorch.LightningEvaluator
.. autoclass:: nni.compression.pytorch.LightningEvaluator
TransformersEvaluator
---------------------
.. autoclass:: nni.compression.pytorch.TransformersEvaluator
examples/model_compress/pruning/taylorfo_transformers_evaluator.py
0 → 100644
View file @
bbf54a88
import
numpy
as
np
from
datasets
import
load_dataset
,
load_metric
from
transformers
import
(
AutoTokenizer
,
AutoModelForSequenceClassification
,
Trainer
,
TrainingArguments
)
import
nni
from
nni.compression.pytorch
import
TransformersEvaluator
from
nni.compression.pytorch.pruning
import
TaylorFOWeightPruner
dataset
=
load_dataset
(
'yelp_review_full'
)
tokenizer
=
AutoTokenizer
.
from_pretrained
(
'bert-base-cased'
)
def
tokenize_function
(
examples
):
return
tokenizer
(
examples
[
'text'
],
padding
=
'max_length'
,
truncation
=
True
)
tokenized_datasets
=
dataset
.
map
(
tokenize_function
,
batched
=
True
)
small_train_dataset
=
tokenized_datasets
[
'train'
].
shuffle
(
seed
=
42
).
select
(
range
(
1000
))
small_eval_dataset
=
tokenized_datasets
[
'test'
].
shuffle
(
seed
=
42
).
select
(
range
(
1000
))
model
=
AutoModelForSequenceClassification
.
from_pretrained
(
'bert-base-cased'
,
num_labels
=
5
)
training_args
=
TrainingArguments
(
output_dir
=
'test_trainer'
)
metric
=
load_metric
(
'accuracy'
)
def
compute_metrics
(
eval_pred
):
logits
,
labels
=
eval_pred
predictions
=
np
.
argmax
(
logits
,
axis
=-
1
)
return
metric
.
compute
(
predictions
=
predictions
,
references
=
labels
)
training_args
=
TrainingArguments
(
output_dir
=
'./log'
,
evaluation_strategy
=
'epoch'
,
per_device_train_batch_size
=
32
,
num_train_epochs
=
3
,
max_steps
=-
1
)
trainer
=
nni
.
trace
(
Trainer
)(
model
=
model
,
args
=
training_args
,
train_dataset
=
small_train_dataset
,
eval_dataset
=
small_eval_dataset
,
compute_metrics
=
compute_metrics
)
evaluator
=
TransformersEvaluator
(
trainer
)
pruner
=
TaylorFOWeightPruner
(
model
,
[{
'op_types'
:
[
'Linear'
],
'sparsity'
:
0.5
}],
evaluator
,
20
)
_
,
masks
=
pruner
.
compress
()
pruner
.
show_pruned_weights
()
nni/algorithms/compression/v2/pytorch/__init__.py
View file @
bbf54a88
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
.utils
import
LightningEvaluator
,
TorchEvaluator
from
.utils
import
Evaluator
,
LightningEvaluator
,
TorchEvaluator
,
TransformersEvaluator
nni/algorithms/compression/v2/pytorch/pruning/amc_pruner.py
View file @
bbf54a88
...
@@ -18,7 +18,7 @@ from nni.compression.pytorch.utils import count_flops_params
...
@@ -18,7 +18,7 @@ from nni.compression.pytorch.utils import count_flops_params
from
.iterative_pruner
import
IterativePruner
,
PRUNER_DICT
from
.iterative_pruner
import
IterativePruner
,
PRUNER_DICT
from
.tools
import
TaskGenerator
from
.tools
import
TaskGenerator
from
.tools.rl_env
import
DDPG
,
AMCEnv
from
.tools.rl_env
import
DDPG
,
AMCEnv
from
..utils
import
LightningEvaluator
,
Torch
Evaluator
,
compute_sparsity
,
config_list_canonical
from
..utils
import
Evaluator
,
compute_sparsity
,
config_list_canonical
from
..utils.docstring
import
_EVALUATOR_DOCSTRING
from
..utils.docstring
import
_EVALUATOR_DOCSTRING
...
@@ -234,7 +234,7 @@ class AMCPruner(IterativePruner):
...
@@ -234,7 +234,7 @@ class AMCPruner(IterativePruner):
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
@
overload
@
overload
def
__init__
(
self
,
total_episode
:
int
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
LightningEvaluator
|
Torch
Evaluator
,
def
__init__
(
self
,
total_episode
:
int
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
Evaluator
,
pruning_algorithm
:
str
=
'l1'
,
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
pruning_algorithm
:
str
=
'l1'
,
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
ddpg_params
:
dict
=
{},
pruning_params
:
dict
=
{},
target
:
str
=
'flops'
):
ddpg_params
:
dict
=
{},
pruning_params
:
dict
=
{},
target
:
str
=
'flops'
):
...
...
...
...
nni/algorithms/compression/v2/pytorch/pruning/auto_compress_pruner.py
View file @
bbf54a88
...
@@ -13,7 +13,7 @@ from torch.nn import Module
...
@@ -13,7 +13,7 @@ from torch.nn import Module
from
.basic_pruner
import
ADMMPruner
from
.basic_pruner
import
ADMMPruner
from
.iterative_pruner
import
IterativePruner
,
SimulatedAnnealingPruner
from
.iterative_pruner
import
IterativePruner
,
SimulatedAnnealingPruner
from
.tools
import
LotteryTicketTaskGenerator
from
.tools
import
LotteryTicketTaskGenerator
from
..utils
import
LightningEvaluator
,
Torch
Evaluator
,
OptimizerConstructHelper
from
..utils
import
Evaluator
,
OptimizerConstructHelper
from
..utils.docstring
import
_EVALUATOR_DOCSTRING
from
..utils.docstring
import
_EVALUATOR_DOCSTRING
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -82,7 +82,7 @@ class AutoCompressPruner(IterativePruner):
...
@@ -82,7 +82,7 @@ class AutoCompressPruner(IterativePruner):
admm_params
admm_params
The parameters passed to the ADMMPruner.
The parameters passed to the ADMMPruner.
- evaluator : LightningEvaluator or TorchEvaluator.
- evaluator : LightningEvaluator or TorchEvaluator
or TransformersEvaluator
.
The same with the evaluator of AutoCompressPruner input parameter.
The same with the evaluator of AutoCompressPruner input parameter.
- iterations : int.
- iterations : int.
The total iteration number in admm pruning algorithm.
The total iteration number in admm pruning algorithm.
...
@@ -92,7 +92,7 @@ class AutoCompressPruner(IterativePruner):
...
@@ -92,7 +92,7 @@ class AutoCompressPruner(IterativePruner):
sa_params
sa_params
The parameters passed to the SimulatedAnnealingPruner.
The parameters passed to the SimulatedAnnealingPruner.
- evaluator : LightningEvaluator or TorchEvaluator.
- evaluator : LightningEvaluator or TorchEvaluator
or TransformersEvaluator
.
The same with the evaluator of AutoCompressPruner input parameter.
The same with the evaluator of AutoCompressPruner input parameter.
- start_temperature : float. Default: `100`.
- start_temperature : float. Default: `100`.
Start temperature of the simulated annealing process.
Start temperature of the simulated annealing process.
...
@@ -127,7 +127,7 @@ class AutoCompressPruner(IterativePruner):
...
@@ -127,7 +127,7 @@ class AutoCompressPruner(IterativePruner):
@
overload
@
overload
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
total_iteration
:
int
,
admm_params
:
Dict
,
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
total_iteration
:
int
,
admm_params
:
Dict
,
sa_params
:
Dict
,
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
sa_params
:
Dict
,
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
evaluator
:
LightningEvaluator
|
Torch
Evaluator
|
None
=
None
,
speedup
:
bool
=
False
):
evaluator
:
Evaluator
|
None
=
None
,
speedup
:
bool
=
False
):
...
...
@
overload
@
overload
...
...
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py
View file @
bbf54a88
...
@@ -53,8 +53,6 @@ from ..utils import (
...
@@ -53,8 +53,6 @@ from ..utils import (
OptimizerConstructHelper
,
OptimizerConstructHelper
,
Scaling
,
Scaling
,
Evaluator
,
Evaluator
,
LightningEvaluator
,
TorchEvaluator
,
ForwardHook
,
ForwardHook
,
TensorHook
,
TensorHook
,
config_list_canonical
config_list_canonical
...
@@ -151,7 +149,7 @@ _LEGACY_CRITERION = Callable[[Tensor, Tensor], Tensor]
...
@@ -151,7 +149,7 @@ _LEGACY_CRITERION = Callable[[Tensor, Tensor], Tensor]
# TODO: remove in nni v3.0.
# TODO: remove in nni v3.0.
class
EvaluatorBasedPruner
(
BasicPruner
):
class
EvaluatorBasedPruner
(
BasicPruner
):
evaluator
:
LightningEvaluator
|
Torch
Evaluator
evaluator
:
Evaluator
using_evaluator
:
bool
using_evaluator
:
bool
trainer
:
_LEGACY_TRAINER
trainer
:
_LEGACY_TRAINER
traced_optimizer
:
Optimizer
traced_optimizer
:
Optimizer
...
@@ -163,7 +161,7 @@ class EvaluatorBasedPruner(BasicPruner):
...
@@ -163,7 +161,7 @@ class EvaluatorBasedPruner(BasicPruner):
# return the remaining arguments.
# return the remaining arguments.
if
(
len
(
args
)
>
0
and
isinstance
(
args
[
0
],
Evaluator
))
or
(
len
(
args
)
==
0
and
isinstance
(
kwargs
.
get
(
'evaluator'
,
None
),
Evaluator
)):
if
(
len
(
args
)
>
0
and
isinstance
(
args
[
0
],
Evaluator
))
or
(
len
(
args
)
==
0
and
isinstance
(
kwargs
.
get
(
'evaluator'
,
None
),
Evaluator
)):
init_kwargs
=
self
.
_parse_args
(
new_api
,
args
,
kwargs
,
init_kwargs
)
init_kwargs
=
self
.
_parse_args
(
new_api
,
args
,
kwargs
,
init_kwargs
)
self
.
evaluator
:
LightningEvaluator
|
Torch
Evaluator
=
init_kwargs
.
pop
(
'evaluator'
)
self
.
evaluator
:
Evaluator
=
init_kwargs
.
pop
(
'evaluator'
)
if
not
self
.
evaluator
.
_initialization_complete
:
if
not
self
.
evaluator
.
_initialization_complete
:
self
.
evaluator
.
_init_optimizer_helpers
(
model
)
# type: ignore
self
.
evaluator
.
_init_optimizer_helpers
(
model
)
# type: ignore
self
.
using_evaluator
=
True
self
.
using_evaluator
=
True
...
@@ -579,7 +577,7 @@ class SlimPruner(EvaluatorBasedPruner):
...
@@ -579,7 +577,7 @@ class SlimPruner(EvaluatorBasedPruner):
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
@
overload
@
overload
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
LightningEvaluator
|
Torch
Evaluator
,
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
Evaluator
,
training_epochs
:
int
,
scale
:
float
=
0.0001
,
mode
=
'global'
):
training_epochs
:
int
,
scale
:
float
=
0.0001
,
mode
=
'global'
):
...
...
...
@@ -699,7 +697,7 @@ class ActivationPruner(EvaluatorBasedPruner):
...
@@ -699,7 +697,7 @@ class ActivationPruner(EvaluatorBasedPruner):
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
@
overload
@
overload
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
LightningEvaluator
|
Torch
Evaluator
,
training_steps
:
int
,
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
Evaluator
,
training_steps
:
int
,
activation
:
str
=
'relu'
,
mode
:
str
=
'normal'
,
dummy_input
:
Optional
[
Tensor
]
=
None
):
activation
:
str
=
'relu'
,
mode
:
str
=
'normal'
,
dummy_input
:
Optional
[
Tensor
]
=
None
):
...
...
...
@@ -970,7 +968,7 @@ class TaylorFOWeightPruner(EvaluatorBasedPruner):
...
@@ -970,7 +968,7 @@ class TaylorFOWeightPruner(EvaluatorBasedPruner):
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
@
overload
@
overload
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
LightningEvaluator
|
Torch
Evaluator
,
training_steps
:
int
,
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
Evaluator
,
training_steps
:
int
,
mode
:
str
=
'normal'
,
dummy_input
:
Optional
[
Tensor
]
=
None
):
mode
:
str
=
'normal'
,
dummy_input
:
Optional
[
Tensor
]
=
None
):
...
...
...
@@ -1114,7 +1112,7 @@ class ADMMPruner(EvaluatorBasedPruner):
...
@@ -1114,7 +1112,7 @@ class ADMMPruner(EvaluatorBasedPruner):
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
@
overload
@
overload
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
LightningEvaluator
|
Torch
Evaluator
,
iterations
:
int
,
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
Evaluator
,
iterations
:
int
,
training_epochs
:
int
,
granularity
:
str
=
'fine-grained'
):
training_epochs
:
int
,
granularity
:
str
=
'fine-grained'
):
...
...
...
...
nni/algorithms/compression/v2/pytorch/pruning/basic_scheduler.py
View file @
bbf54a88
...
@@ -15,7 +15,7 @@ from nni.algorithms.compression.v2.pytorch.base import Pruner, BasePruningSchedu
...
@@ -15,7 +15,7 @@ from nni.algorithms.compression.v2.pytorch.base import Pruner, BasePruningSchedu
from
nni.compression.pytorch.speedup
import
ModelSpeedup
from
nni.compression.pytorch.speedup
import
ModelSpeedup
from
.tools
import
TaskGenerator
from
.tools
import
TaskGenerator
from
..utils
import
Evaluator
,
LightningEvaluator
,
TorchEvaluator
from
..utils
import
Evaluator
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -25,7 +25,7 @@ _LEGACY_EVALUATOR = Callable[[Module], float]
...
@@ -25,7 +25,7 @@ _LEGACY_EVALUATOR = Callable[[Module], float]
# TODO: remove in nni v3.0.
# TODO: remove in nni v3.0.
class
EvaluatorBasedPruningScheduler
(
BasePruningScheduler
):
class
EvaluatorBasedPruningScheduler
(
BasePruningScheduler
):
evaluator
:
LightningEvaluator
|
Torch
Evaluator
evaluator
:
Evaluator
using_evaluator
:
bool
using_evaluator
:
bool
finetuner
:
_LEGACY_FINETUNER
finetuner
:
_LEGACY_FINETUNER
_evaluator
:
_LEGACY_EVALUATOR
_evaluator
:
_LEGACY_EVALUATOR
...
@@ -38,7 +38,7 @@ class EvaluatorBasedPruningScheduler(BasePruningScheduler):
...
@@ -38,7 +38,7 @@ class EvaluatorBasedPruningScheduler(BasePruningScheduler):
if
(
len
(
args
)
>
0
and
isinstance
(
args
[
0
],
Evaluator
))
or
\
if
(
len
(
args
)
>
0
and
isinstance
(
args
[
0
],
Evaluator
))
or
\
(
len
(
args
)
==
0
and
isinstance
(
kwargs
.
get
(
'evaluator'
,
None
),
Evaluator
)):
(
len
(
args
)
==
0
and
isinstance
(
kwargs
.
get
(
'evaluator'
,
None
),
Evaluator
)):
init_kwargs
=
self
.
_parse_args
(
new_api
,
args
,
kwargs
,
new_init_kwargs
)
init_kwargs
=
self
.
_parse_args
(
new_api
,
args
,
kwargs
,
new_init_kwargs
)
self
.
evaluator
:
LightningEvaluator
|
Torch
Evaluator
=
init_kwargs
.
pop
(
'evaluator'
)
self
.
evaluator
:
Evaluator
=
init_kwargs
.
pop
(
'evaluator'
)
if
not
self
.
evaluator
.
_initialization_complete
:
if
not
self
.
evaluator
.
_initialization_complete
:
self
.
evaluator
.
_init_optimizer_helpers
(
model
)
# type: ignore
self
.
evaluator
.
_init_optimizer_helpers
(
model
)
# type: ignore
self
.
using_evaluator
=
True
self
.
using_evaluator
=
True
...
@@ -96,7 +96,7 @@ class PruningScheduler(EvaluatorBasedPruningScheduler):
...
@@ -96,7 +96,7 @@ class PruningScheduler(EvaluatorBasedPruningScheduler):
"""
"""
@
overload
@
overload
def
__init__
(
self
,
pruner
:
Pruner
,
task_generator
:
TaskGenerator
,
evaluator
:
LightningEvaluator
|
Torch
Evaluator
,
def
__init__
(
self
,
pruner
:
Pruner
,
task_generator
:
TaskGenerator
,
evaluator
:
Evaluator
,
speedup
:
bool
=
False
,
reset_weight
:
bool
=
False
):
speedup
:
bool
=
False
,
reset_weight
:
bool
=
False
):
...
...
...
...
nni/algorithms/compression/v2/pytorch/pruning/iterative_pruner.py
View file @
bbf54a88
...
@@ -30,8 +30,7 @@ from .tools import (
...
@@ -30,8 +30,7 @@ from .tools import (
)
)
from
..utils
import
(
from
..utils
import
(
OptimizerConstructHelper
,
OptimizerConstructHelper
,
LightningEvaluator
,
Evaluator
TorchEvaluator
)
)
from
..utils.docstring
import
_EVALUATOR_DOCSTRING
from
..utils.docstring
import
_EVALUATOR_DOCSTRING
...
@@ -115,7 +114,7 @@ class LinearPruner(IterativePruner):
...
@@ -115,7 +114,7 @@ class LinearPruner(IterativePruner):
@
overload
@
overload
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
pruning_algorithm
:
str
,
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
pruning_algorithm
:
str
,
total_iteration
:
int
,
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
total_iteration
:
int
,
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
evaluator
:
LightningEvaluator
|
Torch
Evaluator
|
None
=
None
,
speedup
:
bool
=
False
,
evaluator
:
Evaluator
|
None
=
None
,
speedup
:
bool
=
False
,
pruning_params
:
Dict
=
{}):
pruning_params
:
Dict
=
{}):
...
...
...
@@ -197,7 +196,7 @@ class AGPPruner(IterativePruner):
...
@@ -197,7 +196,7 @@ class AGPPruner(IterativePruner):
@
overload
@
overload
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
pruning_algorithm
:
str
,
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
pruning_algorithm
:
str
,
total_iteration
:
int
,
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
total_iteration
:
int
,
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
evaluator
:
LightningEvaluator
|
Torch
Evaluator
|
None
=
None
,
speedup
:
bool
=
False
,
evaluator
:
Evaluator
|
None
=
None
,
speedup
:
bool
=
False
,
pruning_params
:
Dict
=
{}):
pruning_params
:
Dict
=
{}):
...
...
...
@@ -292,7 +291,7 @@ class LotteryTicketPruner(IterativePruner):
...
@@ -292,7 +291,7 @@ class LotteryTicketPruner(IterativePruner):
@
overload
@
overload
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
pruning_algorithm
:
str
,
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
pruning_algorithm
:
str
,
total_iteration
:
int
,
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
total_iteration
:
int
,
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
evaluator
:
LightningEvaluator
|
Torch
Evaluator
|
None
=
None
,
speedup
:
bool
=
False
,
evaluator
:
Evaluator
|
None
=
None
,
speedup
:
bool
=
False
,
reset_weight
:
bool
=
True
,
pruning_params
:
Dict
=
{}):
reset_weight
:
bool
=
True
,
pruning_params
:
Dict
=
{}):
...
...
...
@@ -386,7 +385,7 @@ class SimulatedAnnealingPruner(IterativePruner):
...
@@ -386,7 +385,7 @@ class SimulatedAnnealingPruner(IterativePruner):
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
@
overload
@
overload
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
LightningEvaluator
|
Torch
Evaluator
,
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
Evaluator
,
start_temperature
:
float
=
100
,
stop_temperature
:
float
=
20
,
cool_down_rate
:
float
=
0.9
,
start_temperature
:
float
=
100
,
stop_temperature
:
float
=
20
,
cool_down_rate
:
float
=
0.9
,
perturbation_magnitude
:
float
=
0.35
,
pruning_algorithm
:
str
=
'level'
,
pruning_params
:
Dict
=
{},
perturbation_magnitude
:
float
=
0.35
,
pruning_algorithm
:
str
=
'level'
,
pruning_params
:
Dict
=
{},
log_dir
:
Union
[
str
,
Path
]
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
speedup
:
bool
=
False
):
log_dir
:
Union
[
str
,
Path
]
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
speedup
:
bool
=
False
):
...
...
nni/algorithms/compression/v2/pytorch/pruning/movement_pruner.py
View file @
bbf54a88
...
@@ -27,8 +27,7 @@ from .tools import (
...
@@ -27,8 +27,7 @@ from .tools import (
)
)
from
..utils
import
(
from
..utils
import
(
LightningEvaluator
,
Evaluator
,
TorchEvaluator
,
Scaling
Scaling
)
)
...
@@ -188,7 +187,7 @@ class MovementPruner(EvaluatorBasedPruner):
...
@@ -188,7 +187,7 @@ class MovementPruner(EvaluatorBasedPruner):
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
"""
.
format
(
evaluator_docstring
=
_EVALUATOR_DOCSTRING
)
@
overload
@
overload
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
LightningEvaluator
|
Torch
Evaluator
,
warm_up_step
:
int
,
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
evaluator
:
Evaluator
,
warm_up_step
:
int
,
cool_down_beginning_step
:
int
,
training_epochs
:
int
|
None
=
None
,
training_steps
:
int
|
None
=
None
,
cool_down_beginning_step
:
int
,
training_epochs
:
int
|
None
=
None
,
training_steps
:
int
|
None
=
None
,
regular_scale
:
float
|
None
=
None
,
movement_mode
:
Literal
[
'hard'
,
'soft'
]
=
'hard'
,
regular_scale
:
float
|
None
=
None
,
movement_mode
:
Literal
[
'hard'
,
'soft'
]
=
'hard'
,
sparse_granularity
:
Literal
[
'auto'
,
'finegrained'
]
=
'finegrained'
):
sparse_granularity
:
Literal
[
'auto'
,
'finegrained'
]
=
'finegrained'
):
...
...
nni/algorithms/compression/v2/pytorch/utils/__init__.py
View file @
bbf54a88
...
@@ -14,6 +14,7 @@ from .evaluator import (
...
@@ -14,6 +14,7 @@ from .evaluator import (
Evaluator
,
Evaluator
,
LightningEvaluator
,
LightningEvaluator
,
TorchEvaluator
,
TorchEvaluator
,
TransformersEvaluator
,
Hook
,
Hook
,
BackwardHook
,
BackwardHook
,
ForwardHook
,
ForwardHook
,
...
...
nni/algorithms/compression/v2/pytorch/utils/evaluator.py
View file @
bbf54a88
...
@@ -18,10 +18,18 @@ try:
...
@@ -18,10 +18,18 @@ try:
import
pytorch_lightning
as
pl
import
pytorch_lightning
as
pl
from
pytorch_lightning.callbacks
import
Callback
from
pytorch_lightning.callbacks
import
Callback
except
ImportError
:
except
ImportError
:
L
ightingInstalled
=
False
L
IGHTNING_INSTALLED
=
False
else
:
else
:
L
ightingInstalled
=
True
L
IGHTNING_INSTALLED
=
True
try
:
from
transformers.trainer
import
Trainer
as
HFTrainer
except
ImportError
:
TRANSFORMERS_INSTALLED
=
False
else
:
TRANSFORMERS_INSTALLED
=
True
import
nni
from
nni.common
import
is_traceable
from
nni.common
import
is_traceable
from
.constructor_helper
import
OptimizerConstructHelper
,
LRSchedulerConstructHelper
from
.constructor_helper
import
OptimizerConstructHelper
,
LRSchedulerConstructHelper
...
@@ -297,7 +305,7 @@ class LightningEvaluator(Evaluator):
...
@@ -297,7 +305,7 @@ class LightningEvaluator(Evaluator):
def
__init__
(
self
,
trainer
:
pl
.
Trainer
,
data_module
:
pl
.
LightningDataModule
,
def
__init__
(
self
,
trainer
:
pl
.
Trainer
,
data_module
:
pl
.
LightningDataModule
,
dummy_input
:
Any
|
None
=
None
):
dummy_input
:
Any
|
None
=
None
):
assert
L
ightingInstalled
,
'pytorch_lightning is not installed.'
assert
L
IGHTNING_INSTALLED
,
'pytorch_lightning is not installed.'
err_msg_p
=
'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
err_msg_p
=
'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
err_msg
=
err_msg_p
.
format
(
'pytorch_lightning.Trainer'
,
'pytorch_lightning.Trainer'
)
err_msg
=
err_msg_p
.
format
(
'pytorch_lightning.Trainer'
,
'pytorch_lightning.Trainer'
)
assert
isinstance
(
trainer
,
pl
.
Trainer
)
and
is_traceable
(
trainer
),
err_msg
assert
isinstance
(
trainer
,
pl
.
Trainer
)
and
is_traceable
(
trainer
),
err_msg
...
@@ -766,3 +774,175 @@ class TorchEvaluator(Evaluator):
...
@@ -766,3 +774,175 @@ class TorchEvaluator(Evaluator):
def
get_dummy_input
(
self
)
->
Any
:
def
get_dummy_input
(
self
)
->
Any
:
return
self
.
dummy_input
return
self
.
dummy_input
class
TransformersEvaluator
(
Evaluator
):
"""
TransformersEvaluator is for the users who using Huggingface ``transformers.trainer.Trainer``.
Here is an example for using ``transformers.trainer.Trainer`` to initialize an evaluator:
.. code-block:: python
from transformers.trainer import Trainer
# wrap Trainer class with nni.trace
trainer = nni.trace(Trainer)(model=model)
evaluator = TransformersEvaluator(trainer)
# if you want to using customized optimizer & lr_scheduler, please also wrap Optimzier & _LRScheduler class
optimizer = nni.trace(Adam)(...)
lr_scheduler = nni.trace(LambdaLR)(...)
trainer = nni.trace(Trainer)(model=model, ..., optimizers=(optimizer, lr_scheduler))
evaluator = TransformersEvaluator(trainer)
Parameters
----------
trainer
``nni.trace(transformers.trainer.Trainer)`` instance. The trainer will be re-initialized inside evaluator,
so wrap with ``nni.trace`` is required for getting the initialization arguments.
dummy_input
Optional. The dummy_input is used to trace the graph, it's same with ``example_inputs`` in
`torch.jit.trace <https://pytorch.org/docs/stable/generated/torch.jit.trace.html?highlight=torch%20jit%20trace#torch.jit.trace>`_.
"""
def
__init__
(
self
,
trainer
:
HFTrainer
,
dummy_input
:
Any
|
None
=
None
)
->
None
:
assert
TRANSFORMERS_INSTALLED
,
'transformers is not installed.'
assert
is_traceable
(
trainer
),
f
'Only support traced Trainer, please use nni.trace(Trainer) to initialize the trainer.'
self
.
traced_trainer
=
trainer
self
.
dummy_input
=
dummy_input
self
.
model
:
Module
|
None
=
None
self
.
_ori_trainer_attr
=
{
'get_optimizer_cls_and_kwargs'
:
HFTrainer
.
get_optimizer_cls_and_kwargs
}
self
.
_initialization_complete
=
False
def
_init_optimizer_helpers
(
self
,
pure_model
:
Module
|
pl
.
LightningModule
):
assert
self
.
_initialization_complete
is
False
,
'Evaluator initialization is already complete.'
if
self
.
traced_trainer
.
optimizer
is
not
None
and
is_traceable
(
self
.
traced_trainer
.
optimizer
):
self
.
_optimizer_helper
=
OptimizerConstructHelper
.
from_trace
(
pure_model
,
self
.
traced_trainer
.
optimizer
)
else
:
warn_msg
=
'trainer.optimzer is not wrapped by nni.trace, or trainer.optimzer is None, '
+
\
'will using huggingface default optimizer.'
_logger
.
warning
(
warn_msg
)
self
.
traced_trainer
.
optimizer
=
None
def
patched_get_optimizer_cls_and_kwargs
(
args
)
->
Tuple
[
Any
,
Any
]:
optimizer_cls
,
optimizer_kwargs
=
self
.
_ori_trainer_attr
[
'get_optimizer_cls_and_kwargs'
](
args
)
return
nni
.
trace
(
optimizer_cls
),
optimizer_kwargs
HFTrainer
.
get_optimizer_cls_and_kwargs
=
patched_get_optimizer_cls_and_kwargs
self
.
_optimizer_helper
=
OptimizerConstructHelper
.
from_trace
(
pure_model
,
self
.
traced_trainer
.
create_optimizer
())
HFTrainer
.
get_optimizer_cls_and_kwargs
=
self
.
_ori_trainer_attr
[
'get_optimizer_cls_and_kwargs'
]
self
.
traced_trainer
.
optimizer
=
None
if
self
.
traced_trainer
.
lr_scheduler
is
not
None
and
is_traceable
(
self
.
traced_trainer
.
lr_scheduler
):
self
.
_lr_scheduler_helper
=
LRSchedulerConstructHelper
.
from_trace
(
self
.
traced_trainer
.
lr_scheduler
)
else
:
warn_msg
=
'trainer.lr_scheduler is not wrapped by nni.trace, or trainer.lr_scheduler is None, '
+
\
'will using huggingface default lr_scheduler.'
_logger
.
warning
(
warn_msg
)
self
.
traced_trainer
.
lr_scheduler
=
None
self
.
_lr_scheduler_helper
=
None
self
.
_initialization_complete
=
True
def
bind_model
(
self
,
model
:
Module
|
pl
.
LightningModule
,
param_names_map
:
Dict
[
str
,
str
]
|
None
=
None
):
err_msg
=
'Evaluator initialization is not complete, please call `_init_optimizer_helpers` before bind model.'
assert
self
.
_initialization_complete
is
True
,
err_msg
assert
isinstance
(
model
,
Module
)
if
self
.
model
is
not
None
:
_logger
.
warning
(
'Already bound a model, will unbind it before bind a new model.'
)
self
.
unbind_model
()
self
.
model
=
model
# re-initialized Trainer
args
=
list
(
self
.
traced_trainer
.
trace_args
)
kwargs
=
dict
()
kwargs
.
update
(
self
.
traced_trainer
.
trace_kwargs
)
if
len
(
args
)
!=
0
:
assert
isinstance
(
args
[
0
],
Module
)
or
args
[
0
]
is
None
args
[
0
]
=
self
.
model
else
:
kwargs
[
'model'
]
=
self
.
model
self
.
trainer
:
HFTrainer
=
self
.
traced_trainer
.
trace_symbol
(
*
args
,
**
kwargs
)
self
.
_ori_trainer_attr
[
'compute_loss'
]
=
self
.
trainer
.
compute_loss
self
.
_param_names_map
=
param_names_map
self
.
trainer
.
optimizer
=
self
.
_optimizer_helper
.
call
(
self
.
model
,
self
.
_param_names_map
)
self
.
_ori_trainer_attr
[
'optimizer.step'
]
=
self
.
trainer
.
optimizer
.
step
def
unbind_model
(
self
):
if
self
.
model
:
self
.
revert_loss
()
self
.
revert_optimizer_step
()
self
.
remove_all_hooks
()
self
.
_ori_trainer_attr
.
pop
(
'optimizer.step'
,
None
)
self
.
trainer
.
optimizer
=
None
self
.
_param_names_map
=
None
self
.
_ori_trainer_attr
.
pop
(
'compute_loss'
,
None
)
self
.
trainer
=
None
self
.
model
=
None
else
:
_logger
.
warning
(
'Did not bind any model, no need to unbind model.'
)
def
patch_loss
(
self
,
patch
:
Callable
[[
Tensor
],
Tensor
]):
old_compute_loss
=
self
.
trainer
.
compute_loss
def
patched_compute_loss
(
_
,
model
:
Any
,
inputs
:
Any
,
return_outputs
:
bool
=
False
):
result
=
old_compute_loss
(
model
,
inputs
,
return_outputs
)
if
return_outputs
:
return
patch
(
result
[
0
]),
result
[
1
]
else
:
return
patch
(
result
)
self
.
trainer
.
compute_loss
=
types
.
MethodType
(
patched_compute_loss
,
self
.
trainer
)
def
revert_loss
(
self
):
self
.
trainer
.
compute_loss
=
self
.
_ori_trainer_attr
[
'compute_loss'
]
def
patch_optimizer_step
(
self
,
before_step_tasks
:
List
[
Callable
],
after_step_tasks
:
List
[
Callable
]):
assert
self
.
trainer
.
optimizer
is
not
None
old_step
=
self
.
trainer
.
optimizer
.
step
def
patched_step
(
_
,
*
args
,
**
kwargs
):
for
task
in
before_step_tasks
:
task
()
# call origin optimizer step method
output
=
old_step
(
*
args
,
**
kwargs
)
for
task
in
after_step_tasks
:
task
()
return
output
self
.
trainer
.
optimizer
.
step
=
types
.
MethodType
(
patched_step
,
self
.
trainer
.
optimizer
)
def
revert_optimizer_step
(
self
):
assert
self
.
trainer
.
optimizer
is
not
None
self
.
trainer
.
optimizer
.
step
=
self
.
_ori_trainer_attr
[
'optimizer.step'
]
def
train
(
self
,
max_steps
:
int
|
None
=
None
,
max_epochs
:
int
|
None
=
None
):
assert
self
.
model
is
not
None
ori_steps
,
ori_epochs
=
self
.
trainer
.
args
.
max_steps
,
self
.
trainer
.
args
.
num_train_epochs
if
max_epochs
is
not
None
:
self
.
trainer
.
args
.
num_train_epochs
=
max_epochs
if
max_steps
is
not
None
:
self
.
trainer
.
args
.
max_steps
=
max_steps
self
.
trainer
.
lr_scheduler
=
self
.
_lr_scheduler_helper
.
call
(
self
.
trainer
.
optimizer
)
if
self
.
_lr_scheduler_helper
else
None
self
.
trainer
.
train
()
self
.
trainer
.
lr_scheduler
=
None
self
.
trainer
.
args
.
max_steps
,
self
.
trainer
.
args
.
num_train_epochs
=
ori_steps
,
ori_epochs
def
finetune
(
self
):
self
.
train
()
def
evaluate
(
self
)
->
float
|
None
|
Tuple
[
float
,
Dict
[
str
,
Any
]]
|
Tuple
[
None
,
Dict
[
str
,
Any
]]:
return
self
.
trainer
.
evaluate
()
def
get_dummy_input
(
self
)
->
Any
:
return
self
.
dummy_input
nni/compression/pytorch/__init__.py
View file @
bbf54a88
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
nni.algorithms.compression.v2.pytorch
import
TorchEvaluator
,
LightningEvaluator
from
nni.algorithms.compression.v2.pytorch
import
TorchEvaluator
,
LightningEvaluator
,
TransformersEvaluator
from
.speedup
import
ModelSpeedup
from
.speedup
import
ModelSpeedup
from
.compressor
import
Compressor
,
Pruner
,
Quantizer
from
.compressor
import
Compressor
,
Pruner
,
Quantizer
from
.utils.apply_compression
import
apply_compression_results
from
.utils.apply_compression
import
apply_compression_results
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment