Unverified commit e3e17f47, authored Oct 12, 2021 by J-shang, committed via GitHub on Oct 12, 2021.
[Model Compression] Add Unit Test (#4125)
Parent: 9a68cdb2
Showing 17 changed files with 773 additions and 199 deletions (+773 −199).
examples/model_compress/pruning/v2/naive_prune_torch.py  +0 −153
examples/model_compress/pruning/v2/scheduler_torch.py  +103 −0
examples/model_compress/pruning/v2/simple_pruning_torch.py  +83 −0
nni/algorithms/compression/v2/pytorch/base/compressor.py  +3 −3
nni/algorithms/compression/v2/pytorch/base/scheduler.py  +3 −3
nni/algorithms/compression/v2/pytorch/pruning/__init__.py  +2 −0
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py  +4 −5
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py  +15 −16
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py  +4 −2
nni/algorithms/compression/v2/pytorch/pruning/tools/task_generator.py  +16 −16
nni/algorithms/compression/v2/pytorch/utils/__init__.py  +11 −0
nni/algorithms/compression/v2/pytorch/utils/pruning.py  +29 −0
nni/compression/pytorch/utils/shape_dependency.py  +2 −1
test/ut/sdk/test_v2_pruner_torch.py  +158 −0
test/ut/sdk/test_v2_pruning_tools_torch.py  +201 −0
test/ut/sdk/test_v2_scheduler.py  +46 −0
test/ut/sdk/test_v2_task_generator.py  +93 −0
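The four added test modules under test/ut/sdk are plain unittest files. A minimal sketch of exercising only the new coverage locally (assumptions not shown in this commit: this branch of nni plus PyTorch and torchvision are installed, and the working directory is the repository root; the project's actual CI entry point is not part of this diff):

# Discover and run only the new v2 compression unit tests with the standard library runner.
# The pattern matches the four test files added by this commit.
import unittest

suite = unittest.defaultTestLoader.discover('test/ut/sdk', pattern='test_v2_*.py')
unittest.TextTestRunner(verbosity=2).run(suite)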
examples/model_compress/pruning/v2/naive_prune_torch.py  (deleted, 100644 → 0)

import argparse
import logging
from pathlib import Path

import torch
from torchvision import transforms, datasets

from nni.algorithms.compression.v2.pytorch import pruning
from nni.compression.pytorch import ModelSpeedup
from examples.model_compress.models.cifar10.vgg import VGG

logging.getLogger().setLevel(logging.DEBUG)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VGG().to(device)

normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=200, shuffle=False)

criterion = torch.nn.CrossEntropyLoss()


def trainer(model, optimizer, criterion, epoch=None):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        if batch_idx % 100 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))


def evaluator(model):
    model.eval()
    criterion = torch.nn.NLLLoss()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    test_loss /= len(test_loader.dataset)
    acc = 100 * correct / len(test_loader.dataset)
    print('Test Loss: {} Accuracy: {}%\n'.format(test_loss, acc))
    return acc


optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
fintune_optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)


def main(args):
    if args.pre_train:
        for i in range(1):
            trainer(model, fintune_optimizer, criterion, epoch=i)

    config_list = [{'op_types': ['Conv2d'], 'sparsity_per_layer': 0.8}]

    kwargs = {
        'model': model,
        'config_list': config_list,
    }

    if args.pruner == 'level':
        pruner = pruning.LevelPruner(**kwargs)
    else:
        kwargs['mode'] = args.mode
        if kwargs['mode'] == 'dependency_aware':
            kwargs['dummy_input'] = torch.rand(10, 3, 32, 32).to(device)
        if args.pruner == 'l1norm':
            pruner = pruning.L1NormPruner(**kwargs)
        elif args.pruner == 'l2norm':
            pruner = pruning.L2NormPruner(**kwargs)
        elif args.pruner == 'fpgm':
            pruner = pruning.FPGMPruner(**kwargs)
        else:
            kwargs['trainer'] = trainer
            kwargs['optimizer'] = optimizer
            kwargs['criterion'] = criterion
            if args.pruner == 'slim':
                kwargs['config_list'] = [{'op_types': ['BatchNorm2d'], 'total_sparsity': 0.8, 'max_sparsity_per_layer': 0.9}]
                kwargs['training_epochs'] = 1
                pruner = pruning.SlimPruner(**kwargs)
            elif args.pruner == 'mean_activation':
                pruner = pruning.ActivationMeanRankPruner(**kwargs)
            elif args.pruner == 'apoz':
                pruner = pruning.ActivationAPoZRankPruner(**kwargs)
            elif args.pruner == 'taylorfo':
                pruner = pruning.TaylorFOWeightPruner(**kwargs)

    pruned_model, masks = pruner.compress()
    pruner.show_pruned_weights()

    if args.speed_up:
        tmp_masks = {}
        for name, mask in masks.items():
            tmp_masks[name] = {}
            tmp_masks[name]['weight'] = mask.get('weight_mask')
            if 'bias' in masks:
                tmp_masks[name]['bias'] = mask.get('bias_mask')
        torch.save(tmp_masks, Path('./temp_masks.pth'))
        pruner._unwrap_model()
        ModelSpeedup(model, torch.rand(10, 3, 32, 32).to(device), Path('./temp_masks.pth'))

    if args.finetune:
        for i in range(1):
            trainer(pruned_model, fintune_optimizer, criterion, epoch=i)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='PyTorch CIFAR10 Example for model comporession')
    parser.add_argument('--pruner', type=str, default='l1norm',
                        choices=['level', 'l1norm', 'l2norm', 'slim', 'fpgm', 'mean_activation', 'apoz', 'taylorfo'],
                        help='pruner to use')
    parser.add_argument('--mode', type=str, default='normal',
                        choices=['normal', 'dependency_aware', 'global'])
    parser.add_argument('--pre-train', action='store_true', default=False,
                        help='Whether to pre-train the model')
    parser.add_argument('--speed-up', action='store_true', default=False,
                        help='Whether to speed-up the pruned model')
    parser.add_argument('--finetune', action='store_true', default=False,
                        help='Whether to finetune the pruned model')
    args = parser.parse_args()
    main(args)
examples/model_compress/pruning/v2/scheduler_torch.py  (new file, 0 → 100644)

import functools

from tqdm import tqdm

import torch
from torchvision import datasets, transforms

from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
from nni.algorithms.compression.v2.pytorch.pruning.tools import AGPTaskGenerator
from nni.algorithms.compression.v2.pytorch.pruning.basic_scheduler import PruningScheduler
from examples.model_compress.models.cifar10.vgg import VGG

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=128, shuffle=False)

criterion = torch.nn.CrossEntropyLoss()


def trainer(model, optimizer, criterion, epoch):
    model.train()
    for data, target in tqdm(iterable=train_loader, desc='Epoch {}'.format(epoch)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


def finetuner(model):
    model.train()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()
    for data, target in tqdm(iterable=train_loader, desc='Epoch PFs'):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


def evaluator(model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(iterable=test_loader, desc='Test'):
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    acc = 100 * correct / len(test_loader.dataset)
    print('Accuracy: {}%\n'.format(acc))
    return acc


if __name__ == '__main__':
    model = VGG().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    # pre-train the model
    for i in range(5):
        trainer(model, optimizer, criterion, i)

    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]

    # Make sure initialize task generator at first, this because the model pass to the generator should be an unwrapped model.
    # If you want to initialize pruner at first, you can use the follow code.
    # pruner = L1NormPruner(model, config_list)
    # pruner._unwrap_model()
    # task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermediate_result=True)
    # pruner._wrap_model()

    # you can specify the log_dir, all intermediate results and best result will save under this folder.
    # if you don't want to keep intermediate results, you can set `keep_intermediate_result=False`.
    task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.', keep_intermediate_result=True)
    pruner = L1NormPruner(model, config_list)

    dummy_input = torch.rand(10, 3, 32, 32).to(device)

    # if you just want to keep the final result as the best result, you can pass evaluator as None.
    # or the result with the highest score (given by evaluator) will be the best result.
    # scheduler = PruningScheduler(pruner, task_generator, finetuner=finetuner, speed_up=True, dummy_input=dummy_input, evaluator=evaluator)
    scheduler = PruningScheduler(pruner, task_generator, finetuner=finetuner, speed_up=True, dummy_input=dummy_input, evaluator=None)

    scheduler.compress()
examples/model_compress/pruning/v2/simple_pruning_torch.py  (new file, 0 → 100644)

from tqdm import tqdm

import torch
from torchvision import datasets, transforms

from nni.algorithms.compression.v2.pytorch.pruning import L1NormPruner
from nni.compression.pytorch.speedup import ModelSpeedup
from examples.model_compress.models.cifar10.vgg import VGG

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
train_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=True, transform=transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(32, 4),
        transforms.ToTensor(),
        normalize,
    ]), download=True),
    batch_size=128, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.CIFAR10('./data', train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])),
    batch_size=128, shuffle=False)

criterion = torch.nn.CrossEntropyLoss()


def trainer(model, optimizer, criterion, epoch):
    model.train()
    for data, target in tqdm(iterable=train_loader, desc='Epoch {}'.format(epoch)):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()


def evaluator(model):
    model.eval()
    correct = 0
    with torch.no_grad():
        for data, target in tqdm(iterable=test_loader, desc='Test'):
            data, target = data.to(device), target.to(device)
            output = model(data)
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()
    acc = 100 * correct / len(test_loader.dataset)
    print('Accuracy: {}%\n'.format(acc))
    return acc


if __name__ == '__main__':
    model = VGG().to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
    criterion = torch.nn.CrossEntropyLoss()

    print('\nPre-train the model:')
    for i in range(5):
        trainer(model, optimizer, criterion, i)
        evaluator(model)

    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
    pruner = L1NormPruner(model, config_list)
    _, masks = pruner.compress()

    print('\nThe accuracy with masks:')
    evaluator(model)

    pruner._unwrap_model()
    ModelSpeedup(model, dummy_input=torch.rand(10, 3, 32, 32).to(device), masks_file='simple_masks.pth').speedup_model()

    print('\nThe accuracy after speed up:')
    evaluator(model)

    print('\nFinetune the model after speed up:')
    for i in range(5):
        trainer(model, optimizer, criterion, i)
        evaluator(model)
nni/algorithms/compression/v2/pytorch/base/compressor.py

@@ -3,13 +3,13 @@
 import collections
 import logging
-from typing import List, Dict, Optional, OrderedDict, Tuple, Any
+from typing import List, Dict, Optional, Tuple, Any
 import torch
 from torch.nn import Module
 from nni.common.graph_utils import TorchModuleGraph
-from nni.compression.pytorch.utils import get_module_by_name
+from nni.algorithms.compression.v2.pytorch.utils import get_module_by_name
 _logger = logging.getLogger(__name__)

@@ -149,7 +149,7 @@ class Compressor:
             return None
         return ret
-    def get_modules_wrapper(self) -> OrderedDict[str, Module]:
+    def get_modules_wrapper(self) -> Dict[str, Module]:
         """
         Returns
         -------
nni/algorithms/compression/v2/pytorch/base/scheduler.py

@@ -5,12 +5,12 @@ import gc
 import logging
 import os
 from pathlib import Path
-from typing import List, Dict, Tuple, Literal, Optional
+from typing import List, Dict, Tuple, Optional
 import json_tricks
 import torch
+from torch import Tensor
 from torch.nn import Module
-from torch.tensor import Tensor
 _logger = logging.getLogger(__name__)

@@ -37,7 +37,7 @@ class Task:
         self.masks_path = masks_path
         self.config_list_path = config_list_path
-        self.status: Literal['Pending', 'Running', 'Finished'] = 'Pending'
+        self.status = 'Pending'
         self.score: Optional[float] = None
         self.state = {}
nni/algorithms/compression/v2/pytorch/pruning/__init__.py

 from .basic_pruner import *
+from .basic_scheduler import PruningScheduler
+from .tools import AGPTaskGenerator, LinearTaskGenerator, LotteryTicketTaskGenerator, SimulatedAnnealingTaskGenerator
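With these two re-exports, the scheduler and the task generators resolve from the pruning package root, which is how the new test_v2_scheduler.py and test_v2_task_generator.py import them. A minimal sketch of the shortened import path:

# After this change the following names are importable from the pruning package root
# (previously they had to come from .basic_scheduler / .tools directly).
from nni.algorithms.compression.v2.pytorch.pruning import (
    PruningScheduler,
    AGPTaskGenerator,
    LinearTaskGenerator,
    LotteryTicketTaskGenerator,
    SimulatedAnnealingTaskGenerator,
)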
nni/algorithms/compression/v2/pytorch/pruning/basic_pruner.py

@@ -13,8 +13,7 @@ from torch.nn import Module
 from torch.optim import Optimizer
 from nni.algorithms.compression.v2.pytorch.base.pruner import Pruner
-from nni.algorithms.compression.v2.pytorch.utils.config_validation import CompressorSchema
-from nni.algorithms.compression.v2.pytorch.utils.pruning import config_list_canonical
+from nni.algorithms.compression.v2.pytorch.utils import CompressorSchema, config_list_canonical
 from .tools import (
     DataCollector,

@@ -43,7 +42,7 @@ from .tools import (
 _logger = logging.getLogger(__name__)
 __all__ = ['LevelPruner', 'L1NormPruner', 'L2NormPruner', 'FPGMPruner', 'SlimPruner', 'ActivationPruner',
-           'ActivationAPoZRankPruner', 'ActivationMeanRankPruner', 'TaylorFOWeightPruner']
+           'ActivationAPoZRankPruner', 'ActivationMeanRankPruner', 'TaylorFOWeightPruner', 'ADMMPruner']
 NORMAL_SCHEMA = {
     Or('sparsity', 'sparsity_per_layer'): And(float, lambda n: 0 <= n < 1),

@@ -688,7 +687,7 @@ class ADMMPruner(BasicPruner):
         Supported keys:
             - sparsity : This is to specify the sparsity for each layer in this config to be compressed.
             - sparsity_per_layer : Equals to sparsity.
-            - rho : Penalty parameters in ADMM algorithm.
+            - rho : Penalty parameters in ADMM algorithm. Default: 1e-4.
             - op_types : Operation types to prune.
             - op_names : Operation names to prune.
             - op_partial_names: An auxiliary field collecting matched op_names in model, then this will convert to op_names.

@@ -744,7 +743,7 @@ class ADMMPruner(BasicPruner):
         def patched_criterion(output: Tensor, target: Tensor):
             penalty = torch.tensor(0.0).to(output.device)
             for name, wrapper in self.get_modules_wrapper().items():
-                rho = wrapper.config['rho']
+                rho = wrapper.config.get('rho', 1e-4)
                 penalty += (rho / 2) * torch.sqrt(torch.norm(wrapper.module.weight - self.Z[name] + self.U[name]))
             return origin_criterion(output, target) + penalty
         return patched_criterion
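The switch from wrapper.config['rho'] to wrapper.config.get('rho', 1e-4) makes rho optional in an ADMM config list, falling back to the documented default. A minimal sketch of both forms (the sparsity values are only illustrative; the explicit-rho variant mirrors the new test_admm_pruner unit test):

# 'rho' omitted: the patched criterion falls back to the documented default of 1e-4.
config_list_default_rho = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]

# 'rho' given explicitly, as in test_v2_pruner_torch.py.
config_list_explicit_rho = [{'op_types': ['Conv2d'], 'sparsity': 0.8, 'rho': 1e-3}]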
nni/algorithms/compression/v2/pytorch/pruning/tools/base.py

@@ -452,7 +452,7 @@ class TaskGenerator:
     This class used to generate config list for pruner in each iteration.
     """
     def __init__(self, origin_model: Module, origin_masks: Dict[str, Dict[str, Tensor]] = {},
-                 origin_config_list: List[Dict] = [], log_dir: str = '.', keep_intermidiate_result: bool = False):
+                 origin_config_list: List[Dict] = [], log_dir: str = '.', keep_intermediate_result: bool = False):
         """
         Parameters
         ----------

@@ -465,16 +465,16 @@ class TaskGenerator:
             This means the sparsity provided by the origin_masks should also be recorded in the origin_config_list.
         log_dir
             The log directory use to saving the task generator log.
-        keep_intermidiate_result
+        keep_intermediate_result
             If keeping the intermediate result, including intermediate model and masks during each iteration.
         """
         assert isinstance(origin_model, Module), 'Only support pytorch module.'
         self._log_dir_root = Path(log_dir, datetime.now().strftime('%Y-%m-%d-%H-%M-%S-%f')).absolute()
         self._log_dir_root.mkdir(parents=True, exist_ok=True)
-        self._keep_intermidiate_result = keep_intermidiate_result
-        self._intermidiate_result_dir = Path(self._log_dir_root, 'intermidiate_result')
-        self._intermidiate_result_dir.mkdir(parents=True, exist_ok=True)
+        self._keep_intermediate_result = keep_intermediate_result
+        self._intermediate_result_dir = Path(self._log_dir_root, 'intermediate_result')
+        self._intermediate_result_dir.mkdir(parents=True, exist_ok=True)
         # save origin data in {log_dir}/origin
         self._origin_model_path = Path(self._log_dir_root, 'origin', 'model.pth')

@@ -506,16 +506,15 @@ class TaskGenerator:
     def update_best_result(self, task_result: TaskResult):
         score = task_result.score
-        task_id = task_result.task_id
-        task = self._tasks[task_id]
-        task.score = score
-        if self._best_score is None or score > self._best_score:
-            self._best_score = score
-            self._best_task_id = task_id
-            with Path(task.config_list_path).open('r') as fr:
-                best_config_list = json_tricks.load(fr)
-            self._save_data('best_result', task_result.compact_model, task_result.compact_model_masks, best_config_list)
+        if score is not None:
+            task_id = task_result.task_id
+            task = self._tasks[task_id]
+            task.score = score
+            if self._best_score is None or score > self._best_score:
+                self._best_score = score
+                self._best_task_id = task_id
+                with Path(task.config_list_path).open('r') as fr:
+                    best_config_list = json_tricks.load(fr)
+                self._save_data('best_result', task_result.compact_model, task_result.compact_model_masks, best_config_list)
     def init_pending_tasks(self) -> List[Task]:
         raise NotImplementedError()

@@ -540,7 +539,7 @@ class TaskGenerator:
         self._pending_tasks.extend(self.generate_tasks(task_result))
         self._dump_tasks_info()
-        if not self._keep_intermidiate_result:
+        if not self._keep_intermediate_result:
             self._tasks[task_id].clean_up()
     def next(self) -> Optional[Task]:
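Besides guarding update_best_result when no score is available, this file renames the constructor keyword from keep_intermidiate_result to keep_intermediate_result, so callers written against the old spelling need a one-word change. A minimal sketch of the corrected call (assuming model and config_list defined as in the scheduler_torch.py example above):

# The keyword now uses the corrected spelling; the old 'keep_intermidiate_result'
# name would no longer be accepted by the constructor after this change.
task_generator = AGPTaskGenerator(10, model, config_list, log_dir='.',
                                  keep_intermediate_result=True)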
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py

@@ -103,8 +103,10 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
     def _get_dependency(self):
         graph = self.pruner.generate_graph(dummy_input=self.dummy_input)
-        self.channel_depen = ChannelDependency(traced_model=graph.trace).dependency_sets
-        self.group_depen = GroupDependency(traced_model=graph.trace).dependency_sets
+        self.pruner._unwrap_model()
+        self.channel_depen = ChannelDependency(model=self.pruner.bound_model, dummy_input=self.dummy_input, traced_model=graph.trace).dependency_sets
+        self.group_depen = GroupDependency(model=self.pruner.bound_model, dummy_input=self.dummy_input, traced_model=graph.trace).dependency_sets
+        self.pruner._wrap_model()
     def generate_sparsity(self, metrics: Dict) -> Dict[str, Dict[str, Tensor]]:
         self._get_dependency()
nni/algorithms/compression/v2/pytorch/pruning/tools/task_generator.py

@@ -13,7 +13,7 @@ import torch
 from torch.nn import Module
 from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
-from nni.algorithms.compression.v2.pytorch.utils.pruning import (
+from nni.algorithms.compression.v2.pytorch.utils import (
     config_list_canonical,
     compute_sparsity,
     get_model_weights_numel

@@ -25,7 +25,7 @@ _logger = logging.getLogger(__name__)
 class FunctionBasedTaskGenerator(TaskGenerator):
     def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
-                 origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
+                 origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermediate_result: bool = False):
         """
         Parameters
         ----------

@@ -40,7 +40,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
             The pre masks on the origin model. This mask maybe user-defined or maybe generate by previous pruning.
         log_dir
             The log directory use to saving the task generator log.
-        keep_intermidiate_result
+        keep_intermediate_result
             If keeping the intermediate result, including intermediate model and masks during each iteration.
         """
         self.current_iteration = 0

@@ -48,7 +48,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
         self.total_iteration = total_iteration
         super().__init__(origin_model, origin_config_list=self.target_sparsity, origin_masks=origin_masks,
-                         log_dir=log_dir, keep_intermidiate_result=keep_intermidiate_result)
+                         log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
     def init_pending_tasks(self) -> List[Task]:
         origin_model = torch.load(self._origin_model_path)

@@ -62,9 +62,9 @@ class FunctionBasedTaskGenerator(TaskGenerator):
         compact_model = task_result.compact_model
         compact_model_masks = task_result.compact_model_masks
-        # save intermidiate result
-        model_path = Path(self._intermidiate_result_dir, '{}_compact_model.pth'.format(task_result.task_id))
-        masks_path = Path(self._intermidiate_result_dir, '{}_compact_model_masks.pth'.format(task_result.task_id))
+        # save intermediate result
+        model_path = Path(self._intermediate_result_dir, '{}_compact_model.pth'.format(task_result.task_id))
+        masks_path = Path(self._intermediate_result_dir, '{}_compact_model_masks.pth'.format(task_result.task_id))
         torch.save(compact_model, model_path)
         torch.save(compact_model_masks, masks_path)

@@ -81,7 +81,7 @@ class FunctionBasedTaskGenerator(TaskGenerator):
         task_id = self._task_id_candidate
         new_config_list = self.generate_config_list(self.target_sparsity, self.current_iteration, compact2origin_sparsity)
-        config_list_path = Path(self._intermidiate_result_dir, '{}_config_list.json'.format(task_id))
+        config_list_path = Path(self._intermediate_result_dir, '{}_config_list.json'.format(task_id))
         with Path(config_list_path).open('w') as f:
             json_tricks.dump(new_config_list, f, indent=4)

@@ -124,9 +124,9 @@ class LinearTaskGenerator(FunctionBasedTaskGenerator):
 class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
     def __init__(self, total_iteration: int, origin_model: Module, origin_config_list: List[Dict],
-                 origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermidiate_result: bool = False):
+                 origin_masks: Dict[str, Dict[str, Tensor]] = {}, log_dir: str = '.', keep_intermediate_result: bool = False):
         super().__init__(total_iteration, origin_model, origin_config_list, origin_masks=origin_masks, log_dir=log_dir,
-                         keep_intermidiate_result=keep_intermidiate_result)
+                         keep_intermediate_result=keep_intermediate_result)
         self.current_iteration = 1
     def generate_config_list(self, target_sparsity: List[Dict], iteration: int, compact2origin_sparsity: List[Dict]) -> List[Dict]:

@@ -147,7 +147,7 @@ class LotteryTicketTaskGenerator(FunctionBasedTaskGenerator):
 class SimulatedAnnealingTaskGenerator(TaskGenerator):
     def __init__(self, origin_model: Module, origin_config_list: List[Dict], origin_masks: Dict[str, Dict[str, Tensor]] = {},
                  start_temperature: float = 100, stop_temperature: float = 20, cool_down_rate: float = 0.9,
-                 perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermidiate_result: bool = False):
+                 perturbation_magnitude: float = 0.35, log_dir: str = '.', keep_intermediate_result: bool = False):
         """
         Parameters
         ----------

@@ -168,7 +168,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
             Initial perturbation magnitude to the sparsities. The magnitude decreases with current temperature.
         log_dir
             The log directory use to saving the task generator log.
-        keep_intermidiate_result
+        keep_intermediate_result
             If keeping the intermediate result, including intermediate model and masks during each iteration.
         """
         self.start_temperature = start_temperature

@@ -186,7 +186,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
         self._current_score = None
         super().__init__(origin_model, origin_masks=origin_masks, origin_config_list=origin_config_list,
-                         log_dir=log_dir, keep_intermidiate_result=keep_intermidiate_result)
+                         log_dir=log_dir, keep_intermediate_result=keep_intermediate_result)
     def _adjust_target_sparsity(self):
         """

@@ -288,8 +288,8 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
         origin_model = torch.load(self._origin_model_path)
         origin_masks = torch.load(self._origin_masks_path)
-        self.temp_model_path = Path(self._intermidiate_result_dir, 'origin_compact_model.pth')
-        self.temp_masks_path = Path(self._intermidiate_result_dir, 'origin_compact_model_masks.pth')
+        self.temp_model_path = Path(self._intermediate_result_dir, 'origin_compact_model.pth')
+        self.temp_masks_path = Path(self._intermediate_result_dir, 'origin_compact_model_masks.pth')
         torch.save(origin_model, self.temp_model_path)
         torch.save(origin_masks, self.temp_masks_path)

@@ -319,7 +319,7 @@ class SimulatedAnnealingTaskGenerator(TaskGenerator):
         task_id = self._task_id_candidate
         new_config_list = self._recover_real_sparsity(deepcopy(self._temp_config_list))
-        config_list_path = Path(self._intermidiate_result_dir, '{}_config_list.json'.format(task_id))
+        config_list_path = Path(self._intermediate_result_dir, '{}_config_list.json'.format(task_id))
         with Path(config_list_path).open('w') as f:
             json_tricks.dump(new_config_list, f, indent=4)
nni/algorithms/compression/v2/pytorch/utils/__init__.py

from .config_validation import CompressorSchema
from .pruning import (
    config_list_canonical,
    unfold_config_list,
    dedupe_config_list,
    compute_sparsity_compact2origin,
    compute_sparsity_mask2compact,
    compute_sparsity,
    get_model_weights_numel,
    get_module_by_name
)
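This new __init__.py gives the v2 utils package a flat import surface, and the new unit tests rely on it. A minimal sketch of the resulting import path:

# Both helpers are now re-exported from the utils package root rather than from
# the utils.pruning / utils.config_validation submodules.
from nni.algorithms.compression.v2.pytorch.utils import (
    compute_sparsity_mask2compact,
    get_module_by_name,
)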
nni/algorithms/compression/v2/pytorch/utils/pruning.py

@@ -224,3 +224,32 @@ def get_model_weights_numel(model: Module, config_list: List[Dict], masks: Dict[
         else:
             model_weights_numel[module_name] = module.weight.data.numel()
     return model_weights_numel, masked_rate
+
+
+# FIXME: to avoid circular import, copy this function in this place
+def get_module_by_name(model, module_name):
+    """
+    Get a module specified by its module name
+
+    Parameters
+    ----------
+    model : pytorch model
+        the pytorch model from which to get its module
+    module_name : str
+        the name of the required module
+
+    Returns
+    -------
+    module, module
+        the parent module of the required module, the required module
+    """
+    name_list = module_name.split(".")
+    for name in name_list[:-1]:
+        if hasattr(model, name):
+            model = getattr(model, name)
+        else:
+            return None, None
+    if hasattr(model, name_list[-1]):
+        leaf_module = getattr(model, name_list[-1])
+        return model, leaf_module
+    else:
+        return None, None
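A minimal usage sketch of the copied helper (the module name 'conv1' is only an example, matching the test models in this commit); per the docstring, both return values are None when the dotted name does not resolve:

# Resolve a submodule by its dotted name; returns (parent_module, leaf_module),
# or (None, None) if any path component is missing.
parent, leaf = get_module_by_name(model, 'conv1')
if leaf is not None:
    print(type(parent).__name__, type(leaf).__name__)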
nni/compression/pytorch/utils/shape_dependency.py

@@ -6,6 +6,7 @@ import logging
 import torch
 import numpy as np
 from nni.compression.pytorch.compressor import PrunerModuleWrapper
+from nni.algorithms.compression.v2.pytorch.base import PrunerModuleWrapper as PrunerModuleWrapper_v2
 from .utils import get_module_by_name

@@ -390,7 +391,7 @@ class GroupDependency(Dependency):
         """
         node_name = node_group.name
         _, leaf_module = get_module_by_name(self.model, node_name)
-        if isinstance(leaf_module, PrunerModuleWrapper):
+        if isinstance(leaf_module, (PrunerModuleWrapper, PrunerModuleWrapper_v2)):
             leaf_module = leaf_module.module
         assert isinstance(
             leaf_module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d))
test/ut/sdk/test_v2_pruner_torch.py  (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import unittest

import torch
import torch.nn.functional as F

from nni.algorithms.compression.v2.pytorch.pruning import (
    LevelPruner,
    L1NormPruner,
    L2NormPruner,
    SlimPruner,
    FPGMPruner,
    ActivationAPoZRankPruner,
    ActivationMeanRankPruner,
    TaylorFOWeightPruner,
    ADMMPruner
)
from nni.algorithms.compression.v2.pytorch.utils import compute_sparsity_mask2compact


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def trainer(model, optimizer, criterion):
    model.train()
    input = torch.rand(10, 1, 28, 28)
    label = torch.Tensor(list(range(10))).type(torch.LongTensor)
    optimizer.zero_grad()
    output = model(input)
    loss = criterion(output, label)
    loss.backward()
    optimizer.step()


def get_optimizer(model):
    return torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)


criterion = torch.nn.CrossEntropyLoss()


class PrunerTestCase(unittest.TestCase):
    def test_level_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = LevelPruner(model=model, config_list=config_list)
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_l1_norm_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = L1NormPruner(model=model, config_list=config_list, mode='dependency_aware',
                              dummy_input=torch.rand(10, 1, 28, 28))
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_l2_norm_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = L2NormPruner(model=model, config_list=config_list, mode='dependency_aware',
                              dummy_input=torch.rand(10, 1, 28, 28))
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_fpgm_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = FPGMPruner(model=model, config_list=config_list, mode='dependency_aware',
                            dummy_input=torch.rand(10, 1, 28, 28))
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_slim_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['BatchNorm2d'], 'total_sparsity': 0.8}]
        pruner = SlimPruner(model=model, config_list=config_list, trainer=trainer, optimizer=get_optimizer(model),
                            criterion=criterion, training_epochs=1, scale=0.001, mode='global')
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_activation_apoz_rank_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = ActivationAPoZRankPruner(model=model, config_list=config_list, trainer=trainer,
                                          optimizer=get_optimizer(model), criterion=criterion, training_batches=1,
                                          activation='relu', mode='dependency_aware',
                                          dummy_input=torch.rand(10, 1, 28, 28))
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_activation_mean_rank_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = ActivationMeanRankPruner(model=model, config_list=config_list, trainer=trainer,
                                          optimizer=get_optimizer(model), criterion=criterion, training_batches=1,
                                          activation='relu', mode='dependency_aware',
                                          dummy_input=torch.rand(10, 1, 28, 28))
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_taylor_fo_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        pruner = TaylorFOWeightPruner(model=model, config_list=config_list, trainer=trainer,
                                      optimizer=get_optimizer(model), criterion=criterion, training_batches=1,
                                      mode='dependency_aware', dummy_input=torch.rand(10, 1, 28, 28))
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81

    def test_admm_pruner(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8, 'rho': 1e-3}]
        pruner = ADMMPruner(model=model, config_list=config_list, trainer=trainer, optimizer=get_optimizer(model),
                            criterion=criterion, iterations=2, training_epochs=1)
        pruned_model, masks = pruner.compress()
        pruner._unwrap_model()
        sparsity_list = compute_sparsity_mask2compact(pruned_model, masks, config_list)
        assert 0.79 < sparsity_list[0]['total_sparsity'] < 0.81


if __name__ == '__main__':
    unittest.main()
test/ut/sdk/test_v2_pruning_tools_torch.py  (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import unittest

import torch
import torch.nn.functional as F

from nni.algorithms.compression.v2.pytorch.base import Pruner
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
    WeightDataCollector,
    WeightTrainerBasedDataCollector,
    SingleHookTrainerBasedDataCollector
)
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
    NormMetricsCalculator,
    MultiDataNormMetricsCalculator,
    DistMetricsCalculator,
    APoZRankMetricsCalculator,
    MeanRankMetricsCalculator
)
from nni.algorithms.compression.v2.pytorch.pruning.tools import (
    NormalSparsityAllocator,
    GlobalSparsityAllocator
)
from nni.algorithms.compression.v2.pytorch.pruning.tools.base import HookCollectorInfo
from nni.algorithms.compression.v2.pytorch.utils import get_module_by_name


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def trainer(model, optimizer, criterion):
    model.train()
    input = torch.rand(10, 1, 28, 28)
    label = torch.Tensor(list(range(10))).type(torch.LongTensor)
    optimizer.zero_grad()
    output = model(input)
    loss = criterion(output, label)
    loss.backward()
    optimizer.step()


def get_optimizer(model):
    return torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)


criterion = torch.nn.CrossEntropyLoss()


class PruningToolsTestCase(unittest.TestCase):
    def test_data_collector(self):
        model = TorchModel()
        w1 = torch.rand(5, 1, 5, 5)
        w2 = torch.rand(10, 5, 5, 5)
        model.conv1.weight.data = w1
        model.conv2.weight.data = w2
        config_list = [{'op_types': ['Conv2d']}]
        pruner = Pruner(model, config_list)

        # Test WeightDataCollector
        data_collector = WeightDataCollector(pruner)
        data = data_collector.collect()
        assert all(torch.equal(get_module_by_name(model, module_name)[1].module.weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])

        # Test WeightTrainerBasedDataCollector
        def opt_after():
            model.conv1.module.weight.data = torch.ones(5, 1, 5, 5)
            model.conv2.module.weight.data = torch.ones(10, 5, 5, 5)
        data_collector = WeightTrainerBasedDataCollector(pruner, trainer, get_optimizer(model), criterion, 1, opt_after_tasks=[opt_after])
        data = data_collector.collect()
        assert all(torch.equal(get_module_by_name(model, module_name)[1].module.weight.data, data[module_name]) for module_name in ['conv1', 'conv2'])
        assert all(t.numel() == (t == 1).type_as(t).sum().item() for t in data.values())

        # Test SingleHookTrainerBasedDataCollector
        def _collector(buffer, weight_tensor):
            def collect_taylor(grad):
                if len(buffer) < 2:
                    buffer.append(grad.clone().detach())
            return collect_taylor
        hook_targets = {'conv1': model.conv1.module.weight, 'conv2': model.conv2.module.weight}
        collector_info = HookCollectorInfo(hook_targets, 'tensor', _collector)
        data_collector = SingleHookTrainerBasedDataCollector(pruner, trainer, get_optimizer(model), criterion, 2, collector_infos=[collector_info])
        data = data_collector.collect()
        assert all(len(t) == 2 for t in data.values())

    def test_metrics_calculator(self):
        # Test NormMetricsCalculator
        metrics_calculator = NormMetricsCalculator(dim=0, p=2)
        data = {
            '1': torch.ones(3, 3, 3),
            '2': torch.ones(4, 4) * 2
        }
        result = {
            '1': torch.ones(3) * 3,
            '2': torch.ones(4) * 4
        }
        metrics = metrics_calculator.calculate_metrics(data)
        assert all(torch.equal(result[k], v) for k, v in metrics.items())

        # Test DistMetricsCalculator
        metrics_calculator = DistMetricsCalculator(dim=0, p=2)
        data = {
            '1': torch.tensor([[1, 2], [4, 6]], dtype=torch.float32),
            '2': torch.tensor([[0, 0], [1, 1]], dtype=torch.float32)
        }
        result = {
            '1': torch.tensor([5, 5], dtype=torch.float32),
            '2': torch.sqrt(torch.tensor([2, 2], dtype=torch.float32))
        }
        metrics = metrics_calculator.calculate_metrics(data)
        assert all(torch.equal(result[k], v) for k, v in metrics.items())

        # Test MultiDataNormMetricsCalculator
        metrics_calculator = MultiDataNormMetricsCalculator(dim=0, p=1)
        data = {
            '1': [torch.ones(3, 3, 3), torch.ones(3, 3, 3) * 2],
            '2': [torch.ones(4, 4), torch.ones(4, 4) * 2]
        }
        result = {
            '1': torch.ones(3) * 27,
            '2': torch.ones(4) * 12
        }
        metrics = metrics_calculator.calculate_metrics(data)
        assert all(torch.equal(result[k], v) for k, v in metrics.items())

        # Test APoZRankMetricsCalculator
        metrics_calculator = APoZRankMetricsCalculator(dim=1)
        data = {
            '1': [torch.tensor([[1, 0], [0, 1]], dtype=torch.float32), torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
            '2': [torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float32), torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
        }
        result = {
            '1': torch.tensor([0.5, 0.5], dtype=torch.float32),
            '2': torch.tensor([0.25, 0.25, 0.5], dtype=torch.float32)
        }
        metrics = metrics_calculator.calculate_metrics(data)
        assert all(torch.equal(result[k], v) for k, v in metrics.items())

        # Test MeanRankMetricsCalculator
        metrics_calculator = MeanRankMetricsCalculator(dim=1)
        data = {
            '1': [torch.tensor([[1, 0], [0, 1]], dtype=torch.float32), torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)],
            '2': [torch.tensor([[1, 0, 1], [0, 1, 0]], dtype=torch.float32), torch.tensor([[0, 0, 1], [0, 0, 0]], dtype=torch.float32)]
        }
        result = {
            '1': torch.tensor([0.5, 0.5], dtype=torch.float32),
            '2': torch.tensor([0.25, 0.25, 0.5], dtype=torch.float32)
        }
        metrics = metrics_calculator.calculate_metrics(data)
        assert all(torch.equal(result[k], v) for k, v in metrics.items())

    def test_sparsity_allocator(self):
        # Test NormalSparsityAllocator
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
        pruner = Pruner(model, config_list)
        metrics = {
            'conv1': torch.rand(5, 1, 5, 5),
            'conv2': torch.rand(10, 5, 5, 5)
        }
        sparsity_allocator = NormalSparsityAllocator(pruner)
        masks = sparsity_allocator.generate_sparsity(metrics)
        assert all(v['weight'].sum() / v['weight'].numel() == 0.2 for k, v in masks.items())

        # Test GlobalSparsityAllocator
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'total_sparsity': 0.8}]
        pruner = Pruner(model, config_list)
        sparsity_allocator = GlobalSparsityAllocator(pruner)
        masks = sparsity_allocator.generate_sparsity(metrics)
        total_elements, total_masked_elements = 0, 0
        for t in masks.values():
            total_elements += t['weight'].numel()
            total_masked_elements += t['weight'].sum().item()
        assert total_masked_elements / total_elements == 0.2


if __name__ == '__main__':
    unittest.main()
test/ut/sdk/test_v2_scheduler.py  (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import unittest

import torch
import torch.nn.functional as F

from nni.algorithms.compression.v2.pytorch.pruning import PruningScheduler, L1NormPruner, AGPTaskGenerator


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


class PruningSchedulerTestCase(unittest.TestCase):
    def test_pruning_scheduler(self):
        model = TorchModel()
        config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
        task_generator = AGPTaskGenerator(1, model, config_list)
        pruner = L1NormPruner(model, config_list)
        scheduler = PruningScheduler(pruner, task_generator)
        scheduler.compress()


if __name__ == '__main__':
    unittest.main()
test/ut/sdk/test_v2_task_generator.py  (new file, 0 → 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import List
import unittest

import torch
import torch.nn.functional as F

from nni.algorithms.compression.v2.pytorch.base import Task, TaskResult
from nni.algorithms.compression.v2.pytorch.pruning import (
    AGPTaskGenerator,
    LinearTaskGenerator,
    LotteryTicketTaskGenerator,
    SimulatedAnnealingTaskGenerator
)


class TorchModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(1, 5, 5, 1)
        self.bn1 = torch.nn.BatchNorm2d(5)
        self.conv2 = torch.nn.Conv2d(5, 10, 5, 1)
        self.bn2 = torch.nn.BatchNorm2d(10)
        self.fc1 = torch.nn.Linear(4 * 4 * 10, 100)
        self.fc2 = torch.nn.Linear(100, 10)

    def forward(self, x):
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.max_pool2d(x, 2, 2)
        x = x.view(-1, 4 * 4 * 10)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)


def run_task_generator(task_generator_type):
    model = TorchModel()
    config_list = [{'op_types': ['Conv2d'], 'sparsity': 0.8}]
    if task_generator_type == 'agp':
        task_generator = AGPTaskGenerator(5, model, config_list)
    elif task_generator_type == 'linear':
        task_generator = LinearTaskGenerator(5, model, config_list)
    elif task_generator_type == 'lottery_ticket':
        task_generator = LotteryTicketTaskGenerator(5, model, config_list)
    elif task_generator_type == 'simulated_annealing':
        task_generator = SimulatedAnnealingTaskGenerator(model, config_list)

    count = run_task_generator_(task_generator)

    if task_generator_type == 'agp':
        assert count == 6
    elif task_generator_type == 'linear':
        assert count == 6
    elif task_generator_type == 'lottery_ticket':
        assert count == 6
    elif task_generator_type == 'simulated_annealing':
        assert count == 17


def run_task_generator_(task_generator):
    task = task_generator.next()
    factor = 0.9
    count = 0
    while task is not None:
        factor = factor ** 2
        count += 1
        task_result = TaskResult(task.task_id, TorchModel(), {}, {}, 1 - factor)
        task_generator.receive_task_result(task_result)
        task = task_generator.next()
    return count


class TaskGenerator(unittest.TestCase):
    def test_agp_task_generator(self):
        run_task_generator('agp')

    def test_linear_task_generator(self):
        run_task_generator('linear')

    def test_lottery_ticket_task_generator(self):
        run_task_generator('lottery_ticket')

    def test_simulated_annealing_task_generator(self):
        run_task_generator('simulated_annealing')


if __name__ == '__main__':
    unittest.main()