Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
cdb65dac
Unverified
Commit
cdb65dac
authored
Oct 13, 2021
by
J-shang
Committed by
GitHub
Oct 13, 2021
Browse files
[Model Compression] add scheduler high level api (#4236)
parent
abb4dfdb
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
445 additions
and
20 deletions
+445
-20
examples/model_compress/pruning/v2/iterative_pruning_torch.py
...ples/model_compress/pruning/v2/iterative_pruning_torch.py
+86
-0
examples/model_compress/pruning/v2/scheduler_torch.py
examples/model_compress/pruning/v2/scheduler_torch.py
+3
-11
examples/model_compress/pruning/v2/simple_pruning_torch.py
examples/model_compress/pruning/v2/simple_pruning_torch.py
+2
-0
nni/algorithms/compression/v2/pytorch/base/compressor.py
nni/algorithms/compression/v2/pytorch/base/compressor.py
+5
-3
nni/algorithms/compression/v2/pytorch/pruning/__init__.py
nni/algorithms/compression/v2/pytorch/pruning/__init__.py
+1
-1
nni/algorithms/compression/v2/pytorch/pruning/iterative_pruner.py
...rithms/compression/v2/pytorch/pruning/iterative_pruner.py
+259
-0
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py
+3
-3
test/ut/sdk/test_v2_iterative_pruner_torch.py
test/ut/sdk/test_v2_iterative_pruner_torch.py
+83
-0
test/ut/sdk/test_v2_scheduler.py
test/ut/sdk/test_v2_scheduler.py
+2
-1
test/ut/sdk/test_v2_task_generator.py
test/ut/sdk/test_v2_task_generator.py
+1
-1
No files found.
examples/model_compress/pruning/v2/iterative_pruning_torch.py
0 → 100644
View file @
cdb65dac
from
tqdm
import
tqdm
import
torch
from
torchvision
import
datasets
,
transforms
from
nni.algorithms.compression.v2.pytorch.pruning
import
AGPPruner
from
examples.model_compress.models.cifar10.vgg
import
VGG
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
normalize
=
transforms
.
Normalize
((
0.4914
,
0.4822
,
0.4465
),
(
0.2023
,
0.1994
,
0.2010
))
train_loader
=
torch
.
utils
.
data
.
DataLoader
(
datasets
.
CIFAR10
(
'./data'
,
train
=
True
,
transform
=
transforms
.
Compose
([
transforms
.
RandomHorizontalFlip
(),
transforms
.
RandomCrop
(
32
,
4
),
transforms
.
ToTensor
(),
normalize
,
]),
download
=
True
),
batch_size
=
128
,
shuffle
=
True
)
test_loader
=
torch
.
utils
.
data
.
DataLoader
(
datasets
.
CIFAR10
(
'./data'
,
train
=
False
,
transform
=
transforms
.
Compose
([
transforms
.
ToTensor
(),
normalize
,
])),
batch_size
=
128
,
shuffle
=
False
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
def
trainer
(
model
,
optimizer
,
criterion
,
epoch
):
model
.
train
()
for
data
,
target
in
tqdm
(
iterable
=
train_loader
,
desc
=
'Epoch {}'
.
format
(
epoch
)):
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
optimizer
.
zero_grad
()
output
=
model
(
data
)
loss
=
criterion
(
output
,
target
)
loss
.
backward
()
optimizer
.
step
()
def
finetuner
(
model
):
model
.
train
()
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
for
data
,
target
in
tqdm
(
iterable
=
train_loader
,
desc
=
'Epoch PFs'
):
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
optimizer
.
zero_grad
()
output
=
model
(
data
)
loss
=
criterion
(
output
,
target
)
loss
.
backward
()
optimizer
.
step
()
def
evaluator
(
model
):
model
.
eval
()
correct
=
0
with
torch
.
no_grad
():
for
data
,
target
in
tqdm
(
iterable
=
test_loader
,
desc
=
'Test'
):
data
,
target
=
data
.
to
(
device
),
target
.
to
(
device
)
output
=
model
(
data
)
pred
=
output
.
argmax
(
dim
=
1
,
keepdim
=
True
)
correct
+=
pred
.
eq
(
target
.
view_as
(
pred
)).
sum
().
item
()
acc
=
100
*
correct
/
len
(
test_loader
.
dataset
)
print
(
'Accuracy: {}%
\n
'
.
format
(
acc
))
return
acc
if
__name__
==
'__main__'
:
model
=
VGG
().
to
(
device
)
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
criterion
=
torch
.
nn
.
CrossEntropyLoss
()
# pre-train the model
for
i
in
range
(
5
):
trainer
(
model
,
optimizer
,
criterion
,
i
)
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
dummy_input
=
torch
.
rand
(
10
,
3
,
32
,
32
).
to
(
device
)
# if you just want to keep the final result as the best result, you can pass evaluator as None.
# or the result with the highest score (given by evaluator) will be the best result.
# pruner = AGPPruner(model, config_list, 'l1', 10, finetuner=finetuner, speed_up=True, dummy_input=dummy_input, evaluator=evaluator)
pruner
=
AGPPruner
(
model
,
config_list
,
'l1'
,
10
,
finetuner
=
finetuner
,
speed_up
=
True
,
dummy_input
=
dummy_input
,
evaluator
=
None
)
pruner
.
compress
()
_
,
model
,
masks
,
_
,
_
=
pruner
.
get_best_result
()
examples/model_compress/pruning/v2/scheduler_torch.py
View file @
cdb65dac
import
functools
from
tqdm
import
tqdm
import
torch
...
...
@@ -77,20 +76,13 @@ if __name__ == '__main__':
for
i
in
range
(
5
):
trainer
(
model
,
optimizer
,
criterion
,
i
)
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
# Make sure initialize task generator at first, this because the model pass to the generator should be an unwrapped model.
# If you want to initialize pruner at first, you can use the follow code.
# pruner = L1NormPruner(model, config_list)
# pruner._unwrap_model()
# task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermediate_result=True)
# pruner._wrap_model()
# No need to pass model and config_list to pruner during initializing when using scheduler.
pruner
=
L1NormPruner
(
None
,
None
)
# you can specify the log_dir, all intermediate results and best result will save under this folder.
# if you don't want to keep intermediate results, you can set `keep_intermediate_result=False`.
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
task_generator
=
AGPTaskGenerator
(
10
,
model
,
config_list
,
log_dir
=
'.'
,
keep_intermediate_result
=
True
)
pruner
=
L1NormPruner
(
model
,
config_list
)
dummy_input
=
torch
.
rand
(
10
,
3
,
32
,
32
).
to
(
device
)
...
...
examples/model_compress/pruning/v2/simple_pruning_torch.py
View file @
cdb65dac
...
...
@@ -77,6 +77,8 @@ if __name__ == '__main__':
print
(
'
\n
The accuracy after speed up:'
)
evaluator
(
model
)
# Need a new optimizer due to the modules in model will be replaced during speedup.
optimizer
=
torch
.
optim
.
SGD
(
model
.
parameters
(),
lr
=
0.1
,
momentum
=
0.9
,
weight_decay
=
5e-4
)
print
(
'
\n
Finetune the model after speed up:'
)
for
i
in
range
(
5
):
trainer
(
model
,
optimizer
,
criterion
,
i
)
...
...
nni/algorithms/compression/v2/pytorch/base/compressor.py
View file @
cdb65dac
...
...
@@ -37,7 +37,7 @@ class Compressor:
The abstract base pytorch compressor.
"""
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
]):
def
__init__
(
self
,
model
:
Optional
[
Module
]
,
config_list
:
Optional
[
List
[
Dict
]
]
):
"""
Parameters
----------
...
...
@@ -46,9 +46,11 @@ class Compressor:
config_list
The config list used by compressor, usually specifies the 'op_types' or 'op_names' that want to compress.
"""
assert
isinstance
(
model
,
Module
)
self
.
is_wrapped
=
False
self
.
reset
(
model
=
model
,
config_list
=
config_list
)
if
model
is
not
None
:
self
.
reset
(
model
=
model
,
config_list
=
config_list
)
else
:
_logger
.
warning
(
'This compressor is not set model and config_list, waiting for reset() or pass this to scheduler.'
)
def
reset
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
]):
"""
...
...
nni/algorithms/compression/v2/pytorch/pruning/__init__.py
View file @
cdb65dac
from
.basic_pruner
import
*
from
.basic_scheduler
import
PruningScheduler
from
.
tools
import
AGPTaskGenerator
,
LinearTaskGenerator
,
LotteryTicketTaskGenerator
,
SimulatedAnnealingTaskGenerator
from
.
iterative_pruner
import
*
nni/algorithms/compression/v2/pytorch/pruning/iterative_pruner.py
0 → 100644
View file @
cdb65dac
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
logging
from
typing
import
Dict
,
List
,
Callable
,
Optional
from
torch
import
Tensor
from
torch.nn
import
Module
from
.basic_pruner
import
(
LevelPruner
,
L1NormPruner
,
L2NormPruner
,
FPGMPruner
,
SlimPruner
,
ActivationAPoZRankPruner
,
ActivationMeanRankPruner
,
TaylorFOWeightPruner
,
ADMMPruner
)
from
.basic_scheduler
import
PruningScheduler
from
.tools.task_generator
import
(
LinearTaskGenerator
,
AGPTaskGenerator
,
LotteryTicketTaskGenerator
,
SimulatedAnnealingTaskGenerator
)
_logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'LinearPruner'
,
'AGPPruner'
,
'LotteryTicketPruner'
,
'SimulatedAnnealingPruner'
]
PRUNER_DICT
=
{
'level'
:
LevelPruner
,
'l1'
:
L1NormPruner
,
'l2'
:
L2NormPruner
,
'fpgm'
:
FPGMPruner
,
'slim'
:
SlimPruner
,
'apoz'
:
ActivationAPoZRankPruner
,
'mean_activation'
:
ActivationMeanRankPruner
,
'taylorfo'
:
TaylorFOWeightPruner
,
'admm'
:
ADMMPruner
}
class
IterativePruner
(
PruningScheduler
):
def
_wrap_model
(
self
):
"""
Deprecated function.
"""
_logger
.
warning
(
'Nothing will happen when calling this function.
\
This pruner is an iterative pruner and does not directly wrap the model.'
)
def
_unwrap_model
(
self
):
"""
Deprecated function.
"""
_logger
.
warning
(
'Nothing will happen when calling this function.
\
This pruner is an iterative pruner and does not directly wrap the model.'
)
def
export_model
(
self
,
*
args
,
**
kwargs
):
"""
Deprecated function.
"""
_logger
.
warning
(
'Nothing will happen when calling this function.
\
The best result (and intermediate result if keeped) during iteration is under `log_dir` (default:
\\
.).'
)
class
LinearPruner
(
IterativePruner
):
"""
Parameters
----------
model : Module
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
The origin config list provided by the user. Note that this config_list is directly config the origin model.
This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
pruning_algorithm : str
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
total_iteration : int
The total iteration number.
log_dir : str
The log directory use to saving the result, you can find the best result under this folder.
keep_intermediate_result : bool
If keeping the intermediate result, including intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handled all finetune logic, use a pytorch module as input, will be called in each iteration.
speed_up : bool
If set True, speed up the model in each iteration.
dummy_input : Optional[torch.Tensor]
If `speed_up` is True, `dummy_input` is required for trace the model in speed up.
evaluator : Optional[Callable[[Module], float]]
Evaluate the pruned model and give a score.
If evaluator is None, the best result refers to the latest result.
pruning_params : dict
If the pruner corresponding to the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
"""
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
pruning_algorithm
:
str
,
total_iteration
:
int
,
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
finetuner
:
Optional
[
Callable
[[
Module
],
None
]]
=
None
,
speed_up
:
bool
=
False
,
dummy_input
:
Optional
[
Tensor
]
=
None
,
evaluator
:
Optional
[
Callable
[[
Module
],
float
]]
=
None
,
pruning_params
:
dict
=
{}):
task_generator
=
LinearTaskGenerator
(
total_iteration
=
total_iteration
,
origin_model
=
model
,
origin_config_list
=
config_list
,
log_dir
=
log_dir
,
keep_intermediate_result
=
keep_intermediate_result
)
pruner
=
PRUNER_DICT
[
pruning_algorithm
](
None
,
None
,
**
pruning_params
)
super
().
__init__
(
pruner
,
task_generator
,
finetuner
=
finetuner
,
speed_up
=
speed_up
,
dummy_input
=
dummy_input
,
evaluator
=
evaluator
,
reset_weight
=
False
)
class
AGPPruner
(
IterativePruner
):
"""
Parameters
----------
model : Module
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
The origin config list provided by the user. Note that this config_list is directly config the origin model.
This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
pruning_algorithm : str
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
total_iteration : int
The total iteration number.
log_dir : str
The log directory use to saving the result, you can find the best result under this folder.
keep_intermediate_result : bool
If keeping the intermediate result, including intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handled all finetune logic, use a pytorch module as input, will be called in each iteration.
speed_up : bool
If set True, speed up the model in each iteration.
dummy_input : Optional[torch.Tensor]
If `speed_up` is True, `dummy_input` is required for trace the model in speed up.
evaluator : Optional[Callable[[Module], float]]
Evaluate the pruned model and give a score.
If evaluator is None, the best result refers to the latest result.
pruning_params : dict
If the pruner corresponding to the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
"""
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
pruning_algorithm
:
str
,
total_iteration
:
int
,
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
finetuner
:
Optional
[
Callable
[[
Module
],
None
]]
=
None
,
speed_up
:
bool
=
False
,
dummy_input
:
Optional
[
Tensor
]
=
None
,
evaluator
:
Optional
[
Callable
[[
Module
],
float
]]
=
None
,
pruning_params
:
dict
=
{}):
task_generator
=
AGPTaskGenerator
(
total_iteration
=
total_iteration
,
origin_model
=
model
,
origin_config_list
=
config_list
,
log_dir
=
log_dir
,
keep_intermediate_result
=
keep_intermediate_result
)
pruner
=
PRUNER_DICT
[
pruning_algorithm
](
None
,
None
,
**
pruning_params
)
super
().
__init__
(
pruner
,
task_generator
,
finetuner
=
finetuner
,
speed_up
=
speed_up
,
dummy_input
=
dummy_input
,
evaluator
=
evaluator
,
reset_weight
=
False
)
class
LotteryTicketPruner
(
IterativePruner
):
"""
Parameters
----------
model : Module
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
The origin config list provided by the user. Note that this config_list is directly config the origin model.
This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
pruning_algorithm : str
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
total_iteration : int
The total iteration number.
log_dir : str
The log directory use to saving the result, you can find the best result under this folder.
keep_intermediate_result : bool
If keeping the intermediate result, including intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handled all finetune logic, use a pytorch module as input, will be called in each iteration.
speed_up : bool
If set True, speed up the model in each iteration.
dummy_input : Optional[torch.Tensor]
If `speed_up` is True, `dummy_input` is required for trace the model in speed up.
evaluator : Optional[Callable[[Module], float]]
Evaluate the pruned model and give a score.
If evaluator is None, the best result refers to the latest result.
reset_weight : bool
If set True, the model weight will reset to the original model weight at the end of each iteration step.
pruning_params : dict
If the pruner corresponding to the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
"""
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
pruning_algorithm
:
str
,
total_iteration
:
int
,
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
finetuner
:
Optional
[
Callable
[[
Module
],
None
]]
=
None
,
speed_up
:
bool
=
False
,
dummy_input
:
Optional
[
Tensor
]
=
None
,
evaluator
:
Optional
[
Callable
[[
Module
],
float
]]
=
None
,
reset_weight
:
bool
=
True
,
pruning_params
:
dict
=
{}):
task_generator
=
LotteryTicketTaskGenerator
(
total_iteration
=
total_iteration
,
origin_model
=
model
,
origin_config_list
=
config_list
,
log_dir
=
log_dir
,
keep_intermediate_result
=
keep_intermediate_result
)
pruner
=
PRUNER_DICT
[
pruning_algorithm
](
None
,
None
,
**
pruning_params
)
super
().
__init__
(
pruner
,
task_generator
,
finetuner
=
finetuner
,
speed_up
=
speed_up
,
dummy_input
=
dummy_input
,
evaluator
=
evaluator
,
reset_weight
=
reset_weight
)
class
SimulatedAnnealingPruner
(
IterativePruner
):
"""
Parameters
----------
model : Module
The origin unwrapped pytorch model to be pruned.
config_list : List[Dict]
The origin config list provided by the user. Note that this config_list is directly config the origin model.
This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
pruning_algorithm : str
Supported pruning algorithm ['level', 'l1', 'l2', 'fpgm', 'slim', 'apoz', 'mean_activation', 'taylorfo', 'admm'].
This iterative pruner will use the chosen corresponding pruner to prune the model in each iteration.
evaluator : Callable[[Module], float]
Evaluate the pruned model and give a score.
start_temperature : float
Start temperature of the simulated annealing process.
stop_temperature : float
Stop temperature of the simulated annealing process.
cool_down_rate : float
Cool down rate of the temperature.
perturbation_magnitude : float
Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
log_dir : str
The log directory use to saving the result, you can find the best result under this folder.
keep_intermediate_result : bool
If keeping the intermediate result, including intermediate model and masks during each iteration.
finetuner : Optional[Callable[[Module], None]]
The finetuner handled all finetune logic, use a pytorch module as input, will be called in each iteration.
speed_up : bool
If set True, speed up the model in each iteration.
dummy_input : Optional[torch.Tensor]
If `speed_up` is True, `dummy_input` is required for trace the model in speed up.
pruning_params : dict
If the pruner corresponding to the chosen pruning_algorithm has extra parameters, put them as a dict to pass in.
"""
def
__init__
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
],
pruning_algorithm
:
str
,
evaluator
:
Callable
[[
Module
],
float
],
start_temperature
:
float
=
100
,
stop_temperature
:
float
=
20
,
cool_down_rate
:
float
=
0.9
,
perturbation_magnitude
:
float
=
0.35
,
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
finetuner
:
Optional
[
Callable
[[
Module
],
None
]]
=
None
,
speed_up
:
bool
=
False
,
dummy_input
:
Optional
[
Tensor
]
=
None
,
pruning_params
:
dict
=
{}):
task_generator
=
SimulatedAnnealingTaskGenerator
(
origin_model
=
model
,
origin_config_list
=
config_list
,
start_temperature
=
start_temperature
,
stop_temperature
=
stop_temperature
,
cool_down_rate
=
cool_down_rate
,
perturbation_magnitude
=
perturbation_magnitude
,
log_dir
=
log_dir
,
keep_intermediate_result
=
keep_intermediate_result
)
pruner
=
PRUNER_DICT
[
pruning_algorithm
](
None
,
None
,
**
pruning_params
)
super
().
__init__
(
pruner
,
task_generator
,
finetuner
=
finetuner
,
speed_up
=
speed_up
,
dummy_input
=
dummy_input
,
evaluator
=
evaluator
,
reset_weight
=
False
)
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py
View file @
cdb65dac
...
...
@@ -566,9 +566,9 @@ class TaskGenerator:
return best task id, best compact model, masks on the compact model, score, config list used in this task.
"""
if
self
.
_best_task_id
is
not
None
:
compact_model
=
torch
.
load
(
Path
(
self
.
_log_dir_root
,
'best_result'
,
'
best_
model.pth'
))
compact_model_masks
=
torch
.
load
(
Path
(
self
.
_log_dir_root
,
'best_result'
,
'
best_
masks.pth'
))
with
Path
(
self
.
_log_dir_root
,
'best_result'
,
'
best_
config_list.json'
).
open
(
'r'
)
as
f
:
compact_model
=
torch
.
load
(
Path
(
self
.
_log_dir_root
,
'best_result'
,
'model.pth'
))
compact_model_masks
=
torch
.
load
(
Path
(
self
.
_log_dir_root
,
'best_result'
,
'masks.pth'
))
with
Path
(
self
.
_log_dir_root
,
'best_result'
,
'config_list.json'
).
open
(
'r'
)
as
f
:
config_list
=
json_tricks
.
load
(
f
)
return
self
.
_best_task_id
,
compact_model
,
compact_model_masks
,
self
.
_best_score
,
config_list
return
None
test/ut/sdk/test_v2_iterative_pruner_torch.py
0 → 100644
View file @
cdb65dac
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
random
import
unittest
import
torch
import
torch.nn.functional
as
F
from
nni.algorithms.compression.v2.pytorch.pruning
import
(
LinearPruner
,
AGPPruner
,
LotteryTicketPruner
,
SimulatedAnnealingPruner
)
from
nni.algorithms.compression.v2.pytorch.utils
import
compute_sparsity_mask2compact
class
TorchModel
(
torch
.
nn
.
Module
):
def
__init__
(
self
):
super
().
__init__
()
self
.
conv1
=
torch
.
nn
.
Conv2d
(
1
,
5
,
5
,
1
)
self
.
bn1
=
torch
.
nn
.
BatchNorm2d
(
5
)
self
.
conv2
=
torch
.
nn
.
Conv2d
(
5
,
10
,
5
,
1
)
self
.
bn2
=
torch
.
nn
.
BatchNorm2d
(
10
)
self
.
fc1
=
torch
.
nn
.
Linear
(
4
*
4
*
10
,
100
)
self
.
fc2
=
torch
.
nn
.
Linear
(
100
,
10
)
def
forward
(
self
,
x
):
x
=
F
.
relu
(
self
.
bn1
(
self
.
conv1
(
x
)))
x
=
F
.
max_pool2d
(
x
,
2
,
2
)
x
=
F
.
relu
(
self
.
bn2
(
self
.
conv2
(
x
)))
x
=
F
.
max_pool2d
(
x
,
2
,
2
)
x
=
x
.
view
(
-
1
,
4
*
4
*
10
)
x
=
F
.
relu
(
self
.
fc1
(
x
))
x
=
self
.
fc2
(
x
)
return
F
.
log_softmax
(
x
,
dim
=
1
)
def
evaluator
(
model
):
return
random
.
random
()
class
IterativePrunerTestCase
(
unittest
.
TestCase
):
def
test_linear_pruner
(
self
):
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
pruner
=
LinearPruner
(
model
,
config_list
,
'level'
,
3
,
log_dir
=
'../../logs'
)
pruner
.
compress
()
_
,
pruned_model
,
masks
,
_
,
_
=
pruner
.
get_best_result
()
sparsity_list
=
compute_sparsity_mask2compact
(
pruned_model
,
masks
,
config_list
)
assert
0.79
<
sparsity_list
[
0
][
'total_sparsity'
]
<
0.81
def
test_agp_pruner
(
self
):
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
pruner
=
AGPPruner
(
model
,
config_list
,
'level'
,
3
,
log_dir
=
'../../logs'
)
pruner
.
compress
()
_
,
pruned_model
,
masks
,
_
,
_
=
pruner
.
get_best_result
()
sparsity_list
=
compute_sparsity_mask2compact
(
pruned_model
,
masks
,
config_list
)
assert
0.79
<
sparsity_list
[
0
][
'total_sparsity'
]
<
0.81
def
test_lottery_ticket_pruner
(
self
):
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
pruner
=
LotteryTicketPruner
(
model
,
config_list
,
'level'
,
3
,
log_dir
=
'../../logs'
)
pruner
.
compress
()
_
,
pruned_model
,
masks
,
_
,
_
=
pruner
.
get_best_result
()
sparsity_list
=
compute_sparsity_mask2compact
(
pruned_model
,
masks
,
config_list
)
assert
0.79
<
sparsity_list
[
0
][
'total_sparsity'
]
<
0.81
def
test_simulated_annealing_pruner
(
self
):
model
=
TorchModel
()
config_list
=
[{
'op_types'
:
[
'Conv2d'
],
'sparsity'
:
0.8
}]
pruner
=
SimulatedAnnealingPruner
(
model
,
config_list
,
'level'
,
evaluator
,
start_temperature
=
30
,
log_dir
=
'../../logs'
)
pruner
.
compress
()
_
,
pruned_model
,
masks
,
_
,
_
=
pruner
.
get_best_result
()
sparsity_list
=
compute_sparsity_mask2compact
(
pruned_model
,
masks
,
config_list
)
assert
0.79
<
sparsity_list
[
0
][
'total_sparsity'
]
<
0.81
if
__name__
==
'__main__'
:
unittest
.
main
()
test/ut/sdk/test_v2_scheduler.py
View file @
cdb65dac
...
...
@@ -6,7 +6,8 @@ import unittest
import
torch
import
torch.nn.functional
as
F
from
nni.algorithms.compression.v2.pytorch.pruning
import
PruningScheduler
,
L1NormPruner
,
AGPTaskGenerator
from
nni.algorithms.compression.v2.pytorch.pruning
import
PruningScheduler
,
L1NormPruner
from
nni.algorithms.compression.v2.pytorch.pruning.tools
import
AGPTaskGenerator
class
TorchModel
(
torch
.
nn
.
Module
):
...
...
test/ut/sdk/test_v2_task_generator.py
View file @
cdb65dac
...
...
@@ -8,7 +8,7 @@ import torch
import
torch.nn.functional
as
F
from
nni.algorithms.compression.v2.pytorch.base
import
Task
,
TaskResult
from
nni.algorithms.compression.v2.pytorch.pruning
import
(
from
nni.algorithms.compression.v2.pytorch.pruning
.tools
import
(
AGPTaskGenerator
,
LinearTaskGenerator
,
LotteryTicketTaskGenerator
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment