Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
21539654
"tests/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "5a658d614edb2ec97c9772541e0c6086a2549960"
Unverified
Commit
21539654
authored
May 24, 2022
by
J-shang
Committed by
GitHub
May 24, 2022
Browse files
[Compression] compression experiment - step 1 (#4836)
parent
2fc47247
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
734 additions
and
8 deletions
+734
-8
examples/model_compress/experimental/compression_experiment/demo.py
...odel_compress/experimental/compression_experiment/demo.py
+43
-0
examples/model_compress/experimental/compression_experiment/vessel.py
...el_compress/experimental/compression_experiment/vessel.py
+99
-0
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
...ompression/v2/pytorch/pruning/tools/sparsity_allocator.py
+25
-5
nni/algorithms/compression/v2/pytorch/pruning/tools/task_generator.py
...ms/compression/v2/pytorch/pruning/tools/task_generator.py
+9
-3
nni/compression/experiment/config/__init__.py
nni/compression/experiment/config/__init__.py
+7
-0
nni/compression/experiment/config/compression.py
nni/compression/experiment/config/compression.py
+72
-0
nni/compression/experiment/config/pruner.py
nni/compression/experiment/config/pruner.py
+32
-0
nni/compression/experiment/config/quantizer.py
nni/compression/experiment/config/quantizer.py
+14
-0
nni/compression/experiment/config/utils.py
nni/compression/experiment/config/utils.py
+185
-0
nni/compression/experiment/config/vessel.py
nni/compression/experiment/config/vessel.py
+97
-0
nni/compression/experiment/experiment.py
nni/compression/experiment/experiment.py
+105
-0
nni/compression/experiment/trial_entry.py
nni/compression/experiment/trial_entry.py
+46
-0
No files found.
examples/model_compress/experimental/compression_experiment/demo.py
0 → 100644
View file @
21539654
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Demo script: run an auto-compression experiment on a pre-trained LeNet."""

from pathlib import Path

import torch
from torch.optim import Adam

import nni
from nni.compression.experiment.experiment import CompressionExperiment
from nni.compression.experiment.config import CompressionExperimentConfig, TaylorFOWeightPrunerConfig
from vessel import LeNet, finetuner, evaluator, trainer, criterion, device

# Build the model and pre-train it before searching for a compression config.
model = LeNet().to(device)
# pre-training model
finetuner(model)

# The optimizer class is wrapped with nni.trace so it can be re-constructed in trials.
optimizer = nni.trace(Adam)(model.parameters())
# Batch of MNIST-shaped inputs used for tracing / flops counting.
dummy_input = torch.rand(16, 1, 28, 28).to(device)

# normal experiment setting, no need to set search_space and trial_command
config = CompressionExperimentConfig('local')
config.experiment_name = 'auto compression torch example'
config.trial_concurrency = 1
config.max_trial_number = 10
config.trial_code_directory = Path(__file__).parent
config.tuner.name = 'TPE'
config.tuner.class_args['optimize_mode'] = 'maximize'

# compression experiment specific setting
# single float value means the expected remaining ratio upper limit for flops & params, lower limit for metric
config.compression_setting.flops = 0.2
config.compression_setting.params = 0.5
config.compression_setting.module_types = ['Conv2d', 'Linear']
config.compression_setting.exclude_module_names = ['fc2']
config.compression_setting.pruners = [TaylorFOWeightPrunerConfig()]

experiment = CompressionExperiment(
    config, model, finetuner, evaluator, dummy_input, trainer, optimizer, criterion, device)
experiment.run(8080)
examples/model_compress/experimental/compression_experiment/vessel.py
0 → 100644
View file @
21539654
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
torch.optim
import
Adam
from
torchvision
import
datasets
,
transforms
import
nni
@nni.trace
class LeNet(nn.Module):
    """Classic LeNet-style CNN for 28x28 single-channel (MNIST) images.

    Outputs log-probabilities over 10 classes (``F.log_softmax``), so it pairs
    with ``F.nll_loss`` as the criterion.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        # 9216 = 64 channels * 12 * 12 spatial after two 3x3 convs + 2x2 max-pool
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
# --- device & MNIST data-loading configuration ---
# Fix: probe CUDA availability instead of hard-coding True, so this example
# also runs on CPU-only machines (the original `_use_cuda = True` made
# `torch.device("cuda")` fail without a GPU). On CUDA machines the behavior
# is unchanged.
_use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _use_cuda else "cpu")

_train_kwargs = {'batch_size': 64}
_test_kwargs = {'batch_size': 1000}
if _use_cuda:
    _cuda_kwargs = {'num_workers': 1, 'pin_memory': True, 'shuffle': True}
    _train_kwargs.update(_cuda_kwargs)
    _test_kwargs.update(_cuda_kwargs)

# MNIST per-channel mean / std normalization.
_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# Data loaders are created lazily on first use by trainer() / evaluator().
_train_loader = None
_test_loader = None
def trainer(model, optimizer, criterion):
    """Train ``model`` for one epoch over MNIST with the given optimizer and loss.

    The train loader is built lazily on first call and cached in a module global.
    """
    global _train_loader
    if _train_loader is None:
        dataset = datasets.MNIST('./data', train=True, download=True, transform=_transform)
        _train_loader = torch.utils.data.DataLoader(dataset, **_train_kwargs)
    model.train()
    for batch, labels in _train_loader:
        batch, labels = batch.to(device), labels.to(device)
        optimizer.zero_grad()
        predictions = model(batch)
        loss = criterion(predictions, labels)
        loss.backward()
        optimizer.step()
def evaluator(model):
    """Evaluate ``model`` on the MNIST test set and return accuracy in percent.

    The test loader is built lazily on first call and cached in a module global.
    """
    global _test_loader
    if _test_loader is None:
        dataset = datasets.MNIST('./data', train=False, transform=_transform, download=True)
        _test_loader = torch.utils.data.DataLoader(dataset, **_test_kwargs)
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for batch, labels in _test_loader:
            batch, labels = batch.to(device), labels.to(device)
            logits = model(batch)
            # summed (not averaged) loss, normalized over the full set below
            test_loss += F.nll_loss(logits, labels, reduction='sum').item()
            pred = logits.argmax(dim=1, keepdim=True)
            correct += pred.eq(labels.view_as(pred)).sum().item()
    total = len(_test_loader.dataset)
    test_loss /= total
    acc = 100 * correct / total
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, total, acc))
    return acc
# Loss used throughout this example; matches the model's log_softmax output.
criterion = F.nll_loss


def finetuner(model: nn.Module):
    """Fine-tune (or pre-train) ``model`` for three epochs with a fresh Adam optimizer."""
    optimizer = Adam(model.parameters())
    for _ in range(3):
        trainer(model, optimizer, criterion)
nni/algorithms/compression/v2/pytorch/pruning/tools/sparsity_allocator.py
View file @
21539654
...
@@ -159,7 +159,7 @@ class GlobalSparsityAllocator(SparsityAllocator):
...
@@ -159,7 +159,7 @@ class GlobalSparsityAllocator(SparsityAllocator):
class
Conv2dDependencyAwareAllocator
(
SparsityAllocator
):
class
Conv2dDependencyAwareAllocator
(
SparsityAllocator
):
"""
"""
A
specify
allocator for Conv2d with dependency-aware.
A
n
allocator
specific
for Conv2d with dependency-aware.
"""
"""
def
__init__
(
self
,
pruner
:
Pruner
,
dim
:
int
,
dummy_input
:
Any
):
def
__init__
(
self
,
pruner
:
Pruner
,
dim
:
int
,
dummy_input
:
Any
):
...
@@ -178,26 +178,42 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
...
@@ -178,26 +178,42 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
self
.
_get_dependency
()
self
.
_get_dependency
()
masks
=
{}
masks
=
{}
grouped_metrics
=
{}
grouped_metrics
=
{}
grouped_names
=
set
()
# combine metrics with channel dependence
for
idx
,
names
in
enumerate
(
self
.
channel_depen
):
for
idx
,
names
in
enumerate
(
self
.
channel_depen
):
grouped_metric
=
{
name
:
metrics
[
name
]
for
name
in
names
if
name
in
metrics
}
grouped_metric
=
{
name
:
metrics
[
name
]
for
name
in
names
if
name
in
metrics
}
grouped_names
.
update
(
grouped_metric
.
keys
())
if
self
.
continuous_mask
:
if
self
.
continuous_mask
:
for
name
,
metric
in
grouped_metric
.
items
():
for
name
,
metric
in
grouped_metric
.
items
():
metric
*=
self
.
_compress_mask
(
self
.
pruner
.
get_modules_wrapper
()[
name
].
weight_mask
)
# type: ignore
metric
*=
self
.
_compress_mask
(
self
.
pruner
.
get_modules_wrapper
()[
name
].
weight_mask
)
# type: ignore
if
len
(
grouped_metric
)
>
0
:
if
len
(
grouped_metric
)
>
0
:
grouped_metrics
[
idx
]
=
grouped_metric
grouped_metrics
[
idx
]
=
grouped_metric
# ungrouped metrics stand alone as a group
ungrouped_names
=
set
(
metrics
.
keys
()).
difference
(
grouped_names
)
for
name
in
ungrouped_names
:
idx
+=
1
# type: ignore
grouped_metrics
[
idx
]
=
{
name
:
metrics
[
name
]}
# generate masks
for
_
,
group_metric_dict
in
grouped_metrics
.
items
():
for
_
,
group_metric_dict
in
grouped_metrics
.
items
():
group_metric
=
self
.
_group_metric_calculate
(
group_metric_dict
)
group_metric
=
self
.
_group_metric_calculate
(
group_metric_dict
)
sparsities
=
{
name
:
self
.
pruner
.
get_modules_wrapper
()[
name
].
config
[
'total_sparsity'
]
for
name
in
group_metric_dict
.
keys
()}
sparsities
=
{
name
:
self
.
pruner
.
get_modules_wrapper
()[
name
].
config
[
'total_sparsity'
]
for
name
in
group_metric_dict
.
keys
()}
min_sparsity
=
min
(
sparsities
.
values
())
min_sparsity
=
min
(
sparsities
.
values
())
conv2d_groups
=
[
self
.
group_depen
[
name
]
for
name
in
group_metric_dict
.
keys
()]
# generate group mask
max_conv2d_group
=
np
.
lcm
.
reduce
(
conv2d_groups
)
conv2d_groups
,
group_mask
=
[],
[]
for
name
in
group_metric_dict
.
keys
():
if
name
in
self
.
group_depen
:
conv2d_groups
.
append
(
self
.
group_depen
[
name
])
else
:
# not in group_depen means not a Conv2d layer, in this case, assume the group number is 1
conv2d_groups
.
append
(
1
)
max_conv2d_group
=
np
.
lcm
.
reduce
(
conv2d_groups
)
pruned_per_conv2d_group
=
int
(
group_metric
.
numel
()
/
max_conv2d_group
*
min_sparsity
)
pruned_per_conv2d_group
=
int
(
group_metric
.
numel
()
/
max_conv2d_group
*
min_sparsity
)
conv2d_group_step
=
int
(
group_metric
.
numel
()
/
max_conv2d_group
)
conv2d_group_step
=
int
(
group_metric
.
numel
()
/
max_conv2d_group
)
group_mask
=
[]
for
gid
in
range
(
max_conv2d_group
):
for
gid
in
range
(
max_conv2d_group
):
_start
=
gid
*
conv2d_group_step
_start
=
gid
*
conv2d_group_step
_end
=
(
gid
+
1
)
*
conv2d_group_step
_end
=
(
gid
+
1
)
*
conv2d_group_step
...
@@ -209,11 +225,15 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
...
@@ -209,11 +225,15 @@ class Conv2dDependencyAwareAllocator(SparsityAllocator):
group_mask
.
append
(
conv2d_group_mask
)
group_mask
.
append
(
conv2d_group_mask
)
group_mask
=
torch
.
cat
(
group_mask
,
dim
=
0
)
group_mask
=
torch
.
cat
(
group_mask
,
dim
=
0
)
# generate final mask
for
name
,
metric
in
group_metric_dict
.
items
():
for
name
,
metric
in
group_metric_dict
.
items
():
# We assume the metric value are all positive right now.
# We assume the metric value are all positive right now.
metric
=
metric
*
group_mask
metric
=
metric
*
group_mask
pruned_num
=
int
(
sparsities
[
name
]
*
len
(
metric
))
pruned_num
=
int
(
sparsities
[
name
]
*
len
(
metric
))
threshold
=
torch
.
topk
(
metric
,
pruned_num
,
largest
=
False
)[
0
].
max
()
if
pruned_num
==
0
:
threshold
=
metric
.
min
()
-
1
else
:
threshold
=
torch
.
topk
(
metric
,
pruned_num
,
largest
=
False
)[
0
].
max
()
mask
=
torch
.
gt
(
metric
,
threshold
).
type_as
(
metric
)
mask
=
torch
.
gt
(
metric
,
threshold
).
type_as
(
metric
)
masks
[
name
]
=
self
.
_expand_mask
(
name
,
mask
)
masks
[
name
]
=
self
.
_expand_mask
(
name
,
mask
)
if
self
.
continuous_mask
:
if
self
.
continuous_mask
:
...
...
nni/algorithms/compression/v2/pytorch/pruning/tools/task_generator.py
View file @
21539654
...
@@ -25,7 +25,8 @@ _logger = logging.getLogger(__name__)
...
@@ -25,7 +25,8 @@ _logger = logging.getLogger(__name__)
class
FunctionBasedTaskGenerator
(
TaskGenerator
):
class
FunctionBasedTaskGenerator
(
TaskGenerator
):
def
__init__
(
self
,
total_iteration
:
int
,
origin_model
:
Module
,
origin_config_list
:
List
[
Dict
],
def
__init__
(
self
,
total_iteration
:
int
,
origin_model
:
Module
,
origin_config_list
:
List
[
Dict
],
origin_masks
:
Dict
[
str
,
Dict
[
str
,
Tensor
]]
=
{},
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
):
origin_masks
:
Dict
[
str
,
Dict
[
str
,
Tensor
]]
=
{},
log_dir
:
str
=
'.'
,
keep_intermediate_result
:
bool
=
False
,
skip_first_iteration
:
bool
=
False
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -39,16 +40,21 @@ class FunctionBasedTaskGenerator(TaskGenerator):
...
@@ -39,16 +40,21 @@ class FunctionBasedTaskGenerator(TaskGenerator):
origin_masks
origin_masks
The pre masks on the origin model. This mask maybe user-defined or maybe generate by previous pruning.
The pre masks on the origin model. This mask maybe user-defined or maybe generate by previous pruning.
log_dir
log_dir
The log directory use to sav
ing
the task generator log.
The log directory use
d
to sav
e
the task generator log.
keep_intermediate_result
keep_intermediate_result
If keeping the intermediate result, including intermediate model and masks during each iteration.
If keeping the intermediate result, including intermediate model and masks during each iteration.
skip_first_iteration
If skipping the first iteration, the iteration counter will start at 1.
In these function-based iterative pruning algorithms, iteration `0` means a warm up stage with `sparsity = 0`.
If the `original_model` is a pre-trained model, the first iteration is usually can be skipped.
"""
"""
self
.
total_iteration
=
total_iteration
self
.
total_iteration
=
total_iteration
self
.
skip_first_iteration
=
skip_first_iteration
super
().
__init__
(
origin_model
,
origin_config_list
=
origin_config_list
,
origin_masks
=
origin_masks
,
super
().
__init__
(
origin_model
,
origin_config_list
=
origin_config_list
,
origin_masks
=
origin_masks
,
log_dir
=
log_dir
,
keep_intermediate_result
=
keep_intermediate_result
)
log_dir
=
log_dir
,
keep_intermediate_result
=
keep_intermediate_result
)
def
reset
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
]
=
[],
masks
:
Dict
[
str
,
Dict
[
str
,
Tensor
]]
=
{}):
def
reset
(
self
,
model
:
Module
,
config_list
:
List
[
Dict
]
=
[],
masks
:
Dict
[
str
,
Dict
[
str
,
Tensor
]]
=
{}):
self
.
current_iteration
=
0
self
.
current_iteration
=
1
if
self
.
skip_first_iteration
else
0
self
.
target_sparsity
=
config_list_canonical
(
model
,
config_list
)
self
.
target_sparsity
=
config_list_canonical
(
model
,
config_list
)
super
().
reset
(
model
,
config_list
=
config_list
,
masks
=
masks
)
super
().
reset
(
model
,
config_list
=
config_list
,
masks
=
masks
)
...
...
nni/compression/experiment/config/__init__.py
0 → 100644
View file @
21539654
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.compression
import
CompressionConfig
,
CompressionExperimentConfig
from
.utils
import
generate_compression_search_space
from
.vessel
import
CompressionVessel
from
.pruner
import
TaylorFOWeightPrunerConfig
nni/compression/experiment/config/compression.py
0 → 100644
View file @
21539654
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__
=
[
'CompressionConfig'
,
'CompressionExperimentConfig'
]
from
dataclasses
import
dataclass
from
typing
import
List
,
Optional
,
Type
,
Union
from
torch.nn
import
Module
from
nni.experiment.config
import
ExperimentConfig
from
nni.experiment.config.base
import
ConfigBase
from
.pruner
import
PrunerConfig
from
.quantizer
import
QuantizerConfig
@dataclass
class CompressionConfig(ConfigBase):
    """
    Constraints and scope description for a compression experiment.

    Attributes
    ----------
    params
        The upper bound of the ratio of remaining model parameters.
        E.g., 0.6 means at most 60% parameters are kept while 40% parameters are pruned.
    flops
        The upper bound of the ratio of remaining model flops.
        E.g., 0.6 means at most 60% flops are kept while 40% flops are pruned.
    metric
        The lower bound of the ratio of remaining model metric.
        Metric is the evaluator's return value, usually a float representing model accuracy.
        E.g., 0.9 means the compressed model should keep at least 90% of the original
        model's performance (80% original accuracy -> at least 72% after compression).
    module_types
        The modules of the type in this list will be compressed.
    module_names
        The modules in this list will be compressed.
    exclude_module_names
        The modules in this list will not be compressed.
    pruners
        A list of `PrunerConfig`, possible pruner choices.
    quantizers
        A list of `QuantizerConfig`, possible quantizer choices.
    """
    # constraints
    params: Union[str, int, float, None] = None
    flops: Union[str, int, float, None] = None
    # latency: float | None
    metric: Optional[float] = None
    # compress scope description
    module_types: Optional[List[Union[Type[Module], str]]] = None
    module_names: Optional[List[str]] = None
    exclude_module_names: Optional[List[str]] = None
    # pruning algorithm description
    pruners: Optional[List[PrunerConfig]] = None
    quantizers: Optional[List[QuantizerConfig]] = None
@dataclass(init=False)
class CompressionExperimentConfig(ExperimentConfig):
    """`ExperimentConfig` extended with a ``compression_setting`` section."""

    compression_setting: CompressionConfig

    def __init__(self, training_service_platform=None, compression_setting=None, **kwargs):
        super().__init__(training_service_platform, **kwargs)
        # Fall back to a default (empty) CompressionConfig when none is supplied.
        self.compression_setting = compression_setting if compression_setting else CompressionConfig()
nni/compression/experiment/config/pruner.py
0 → 100644
View file @
21539654
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
dataclasses
import
dataclass
,
asdict
from
typing_extensions
import
Literal
from
nni.experiment.config.base
import
ConfigBase
@dataclass
class PrunerConfig(ConfigBase):
    """
    Base class used to config the initialization parameters of a pruner used in the compression experiment.
    """
    # Fix: the original docstring said "quantizer" — a copy-paste from QuantizerConfig;
    # this class configures pruners (see the concrete subclasses below).
    pruner_type: Literal['Pruner']

    def json(self):
        """Return the canonical dict representation of this config."""
        canon = self.canonical_copy()
        return asdict(canon)
@dataclass
class L1NormPrunerConfig(PrunerConfig):
    """Initialization parameters for an ``L1NormPruner`` choice in the search space."""
    pruner_type: Literal['L1NormPruner'] = 'L1NormPruner'
    mode: Literal['normal', 'dependency_aware'] = 'dependency_aware'
@dataclass
class TaylorFOWeightPrunerConfig(PrunerConfig):
    """Initialization parameters for a ``TaylorFOWeightPruner`` choice in the search space."""
    pruner_type: Literal['TaylorFOWeightPruner'] = 'TaylorFOWeightPruner'
    mode: Literal['normal', 'dependency_aware', 'global'] = 'dependency_aware'
    # number of training batches used to collect first-order Taylor statistics
    training_batches: int = 30
nni/compression/experiment/config/quantizer.py
0 → 100644
View file @
21539654
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
dataclasses
import
dataclass
from
nni.experiment.config.base
import
ConfigBase
@dataclass
class QuantizerConfig(ConfigBase):
    """
    A placeholder for quantizer config.
    Use to config the initialization parameters of a quantizer used in the compression experiment.
    """
nni/compression/experiment/config/utils.py
0 → 100644
View file @
21539654
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
__future__
import
annotations
from
typing
import
Any
,
Tuple
,
Dict
,
List
,
Type
from
torch.nn
import
Module
from
nni.compression.pytorch.utils
import
count_flops_params
from
.compression
import
CompressionConfig
from
.vessel
import
CompressionVessel
# Reserved keys used to round-trip compression metadata through the HPO search
# space / parameters (see generate_compression_search_space / parse_params).
KEY_MODULE_NAME = 'module_name::'
KEY_PRUNERS = 'pruners'
KEY_VESSEL = '_vessel'
KEY_ORIGINAL_TARGET = '_original_target'
KEY_THETAS = '_thetas'
def
_flops_theta_helper
(
target
:
int
|
float
|
str
|
None
,
origin
:
int
)
->
Tuple
[
float
,
float
]:
# hard code and magic number for flops/params reward function
# the reward function is: sigmoid(flops_retained) = 1 / (1 + exp(-theta1 * (flops_retained + theta0)))
# this helper function return a theta pair (theta0, theta1) for building a suitable (maybe) function.
# the lower evaluating result (flops/params) compressed model has, the higher reward it gets.
if
not
target
or
(
isinstance
(
target
,
(
int
,
float
))
and
target
==
0
):
return
(
0.
,
0.
)
elif
isinstance
(
target
,
float
):
assert
0.
<
target
<
1.
return
(
-
0.1
-
target
,
-
50.
)
elif
isinstance
(
target
,
int
):
assert
0
<
target
<
origin
return
(
-
0.1
-
target
/
origin
,
-
50.
)
elif
isinstance
(
target
,
str
):
raise
NotImplementedError
(
'Currently only supports setting the upper bound with int/float.'
)
else
:
raise
TypeError
(
f
'Wrong target type:
{
type
(
target
).
__name__
}
, only support int/float/None.'
)
def
_metric_theta_helper
(
target
:
float
|
None
,
origin
:
float
)
->
Tuple
[
float
,
float
]:
# hard code and magic number for metric reward function
# only difference with `_flops_theta_helper` is the higher evaluating result (metric) compressed model has,
# the higher reward it gets.
if
not
target
:
return
(
-
0.85
,
50.
)
elif
isinstance
(
target
,
float
):
assert
0.
<=
target
<=
1.
return
(
0.1
-
target
,
50.
)
else
:
raise
TypeError
(
f
'Wrong target type:
{
type
(
target
).
__name__
}
, only support float/None.'
)
def
_summary_module_names
(
model
:
Module
,
module_types
:
List
[
Type
[
Module
]
|
str
],
module_names
:
List
[
str
],
exclude_module_names
:
List
[
str
])
->
List
[
str
]:
# Return a list of module names that need to be compressed.
# Include all names of modules that specified in `module_types` and `module_names` at first,
# then remove the names specified in `exclude_module_names`.
_module_types
=
set
()
_all_module_names
=
set
()
module_names_summary
=
set
()
if
module_types
:
for
module_type
in
module_types
:
if
isinstance
(
module_type
,
Module
):
module_type
=
module_type
.
__name__
assert
isinstance
(
module_type
,
str
)
_module_types
.
add
(
module_type
)
# unfold module types as module names, add them to summary
for
module_name
,
module
in
model
.
named_modules
():
module_type
=
type
(
module
).
__name__
if
module_type
in
_module_types
:
module_names_summary
.
add
(
module_name
)
_all_module_names
.
add
(
module_name
)
# add module names to summary
if
module_names
:
for
module_name
in
module_names
:
if
module_name
not
in
_all_module_names
:
# need warning, module_name not exist
continue
else
:
module_names_summary
.
add
(
module_name
)
# remove module names in exclude_module_names from module_names_summary
if
exclude_module_names
:
for
module_name
in
exclude_module_names
:
if
module_name
not
in
_all_module_names
:
# need warning, module_name not exist
continue
if
module_name
in
module_names_summary
:
module_names_summary
.
remove
(
module_name
)
return
list
(
module_names_summary
)
def generate_compression_search_space(config: CompressionConfig,
                                      vessel: CompressionVessel) -> Dict[str, Dict]:
    """
    Using config (constraints & priori) and vessel (model-related) to generate the hpo search space.
    """
    search_space = {}
    model, _, evaluator, dummy_input, _, _, _, _ = vessel.export()

    # baseline measurements on the uncompressed model
    flops, params, results = count_flops_params(model, dummy_input, verbose=False, mode='full')
    metric = evaluator(model)

    # one uniform [0, 1] sparsity knob per compressible module
    module_names_summary = _summary_module_names(
        model, config.module_types, config.module_names, config.exclude_module_names)
    for module_name in module_names_summary:
        key = '{}{}'.format(KEY_MODULE_NAME, module_name)
        search_space[key] = {'_type': 'uniform', '_value': [0, 1]}

    # TODO: hard code for step 1, need refactor
    assert not config.pruners or not config.quantizers
    pruner_choices = [pruner_config.json() for pruner_config in config.pruners]
    search_space[KEY_PRUNERS] = {'_type': 'choice', '_value': pruner_choices}

    original_target = {'flops': flops, 'params': params, 'metric': metric, 'results': results}
    # TODO: following functions need improvement
    thetas = {
        'flops': _flops_theta_helper(config.flops, flops),
        'params': _flops_theta_helper(config.params, params),
        'metric': _metric_theta_helper(config.metric, metric),
    }

    # single-choice fields carrying fixed metadata through the tuner
    search_space[KEY_VESSEL] = {'_type': 'choice', '_value': [vessel.json()]}
    search_space[KEY_ORIGINAL_TARGET] = {'_type': 'choice', '_value': [original_target]}
    search_space[KEY_THETAS] = {'_type': 'choice', '_value': [thetas]}
    return search_space
def parse_params(kwargs: Dict[str, Any]) -> Tuple[Dict[str, str], List[Dict[str, Any]],
                                                  CompressionVessel, Dict[str, Any], Dict[str, Any]]:
    """
    Parse the parameters received by nni.get_next_parameter().

    Returns
    -------
    Dict[str, str], List[Dict[str, Any]], CompressionVessel, Dict[str, Any], Dict[str, Any]
        The compressor config, compressor config_list, model-related wrapper,
        evaluation value (flops, params, ...) for the original model,
        parameters of the hpo objective function.
    """
    compressor_config = None
    vessel = None
    original_target = None
    thetas = None
    config_list = []
    for key, value in kwargs.items():
        if key.startswith(KEY_MODULE_NAME):
            # per-module sparsity knob -> one config_list entry
            module_name = key.split(KEY_MODULE_NAME)[1]
            config_list.append({'op_names': [module_name], 'sparsity_per_layer': float(value)})
        elif key == KEY_PRUNERS:
            compressor_config = value
        elif key == KEY_VESSEL:
            vessel = CompressionVessel(**value)
        elif key == KEY_ORIGINAL_TARGET:
            original_target = value
        elif key == KEY_THETAS:
            thetas = value
        else:
            raise KeyError('Unrecognized key {}'.format(key))
    return compressor_config, config_list, vessel, original_target, thetas
def parse_basic_pruner(pruner_config: Dict[str, str],
                       config_list: List[Dict[str, Any]],
                       vessel: CompressionVessel):
    """
    Parse basic pruner and model-related objects used by pruning scheduler.
    """
    model, finetuner, evaluator, dummy_input, trainer, optimizer_helper, criterion, device = vessel.export()
    pruner_type = pruner_config['pruner_type']
    if pruner_type == 'L1NormPruner':
        from nni.compression.pytorch.pruning import L1NormPruner
        basic_pruner = L1NormPruner(model=model,
                                    config_list=config_list,
                                    mode=pruner_config['mode'],
                                    dummy_input=dummy_input)
    elif pruner_type == 'TaylorFOWeightPruner':
        from nni.compression.pytorch.pruning import TaylorFOWeightPruner
        basic_pruner = TaylorFOWeightPruner(model=model,
                                            config_list=config_list,
                                            trainer=trainer,
                                            traced_optimizer=optimizer_helper,
                                            criterion=criterion,
                                            training_batches=pruner_config['training_batches'],
                                            mode=pruner_config['mode'],
                                            dummy_input=dummy_input)
    else:
        # NOTE(review): `pruner_config` is a dict here, so `pruner_config.pruner_type`
        # in the original message formatting would itself raise; preserved as-is.
        raise ValueError('Unsupported basic pruner type {}'.format(pruner_config.pruner_type))
    return basic_pruner, model, finetuner, evaluator, dummy_input, device
nni/compression/experiment/config/vessel.py
0 → 100644
View file @
21539654
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
base64
import
io
from
dataclasses
import
dataclass
,
asdict
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Optional
,
Tuple
,
Union
,
overload
import
torch
from
torch
import
Tensor
from
torch.nn
import
Module
from
torch.optim
import
Optimizer
from
nni.algorithms.compression.v2.pytorch.utils.constructor_helper
import
OptimizerConstructHelper
from
nni.common
import
dump
,
load
from
nni.experiment.config.base
import
ConfigBase
@dataclass(init=False)
class CompressionVessel(ConfigBase):
    """
    This is an internal class that helps serialize model-related parameters during model compression.

    Each field holds a string-serialized form of the corresponding object: callables and the
    model go through ``nni.common.dump``, the dummy input tensor is ``torch.save``-d into a
    base64 string, and the device is stored as ``str(device)``.

    # FIXME: In fact, it is not a `Config`, the only reason it is a `Config` right now is that
    # its data attribute will go into the search space as a single choice field. Need to
    # refactor after the experiment config is stable.
    """
    model: str
    finetuner: str
    evaluator: str
    dummy_input: str
    trainer: Optional[str]
    optimizer_helper: Optional[str]
    criterion: Optional[str]
    device: str

    @overload
    def __init__(self, model: str, finetuner: str, evaluator: str, dummy_input: str,
                 trainer: str, optimizer_helper: str, criterion: str, device: str):
        ...

    @overload
    def __init__(self, model: Module, finetuner: Callable[[Module], None],
                 evaluator: Callable[[Module], float], dummy_input: Tensor,
                 trainer: Optional[Callable[[Module, Optimizer, Callable[[Any, Any], Any]], None]],
                 optimizer_helper: Union[Optimizer, OptimizerConstructHelper, None],
                 criterion: Optional[Callable[[Any, Any], Any]],
                 device: Union[str, torch.device]):
        ...

    def __init__(self,
                 model: Union[Module, str],
                 finetuner: Union[Callable[[Module], None], str],
                 evaluator: Union[Callable[[Module], float], str],
                 dummy_input: Union[Tensor, str],
                 trainer: Union[Callable[[Module, Optimizer, Callable[[Any, Any], Any]], None], str, None],
                 optimizer_helper: Union[Optimizer, OptimizerConstructHelper, str, None],
                 criterion: Union[Callable[[Any, Any], Any], str, None],
                 device: Union[torch.device, str]):
        # Arguments that are already strings are assumed pre-serialized and pass through.
        self.model = model if isinstance(model, str) else dump(model)
        self.finetuner = finetuner if isinstance(finetuner, str) else dump(finetuner)
        self.evaluator = evaluator if isinstance(evaluator, str) else dump(evaluator)
        if not isinstance(dummy_input, str):
            # serialize the tensor with torch.save, then base64-encode the bytes
            buff = io.BytesIO()
            torch.save(dummy_input, buff)
            buff.seek(0)
            dummy_input = base64.b64encode(buff.read()).decode()
        self.dummy_input = dummy_input
        self.trainer = trainer if isinstance(trainer, str) else dump(trainer)
        if not isinstance(optimizer_helper, str):
            # a traced Optimizer is first normalized into an OptimizerConstructHelper
            if not isinstance(optimizer_helper, OptimizerConstructHelper):
                optimizer_helper = OptimizerConstructHelper.from_trace(model, optimizer_helper)
            optimizer_helper = dump(optimizer_helper)
        self.optimizer_helper = optimizer_helper
        self.criterion = criterion if isinstance(criterion, str) else dump(criterion)
        self.device = str(device)

    def export(self) -> Tuple[Module,
                              Callable[[Module], None],
                              Callable[[Module], float],
                              Tensor,
                              Optional[Callable[[Module, Optimizer, Callable[[Any, Any], Any]], None]],
                              Optional[OptimizerConstructHelper],
                              Optional[Callable[[Any, Any], Any]],
                              torch.device]:
        """Deserialize and return all wrapped objects, moving tensors/model to the stored device.

        If a checkpoint exists at ``nni_outputs/checkpoint/model_state_dict.pth`` (written by a
        previous stage), it is loaded into the model before returning.
        """
        device = torch.device(self.device)
        model = load(self.model)
        if Path('nni_outputs', 'checkpoint', 'model_state_dict.pth').exists():
            model.load_state_dict(torch.load(Path('nni_outputs', 'checkpoint', 'model_state_dict.pth')))
        dummy_input = torch.load(io.BytesIO(base64.b64decode(self.dummy_input.encode()))).to(device)
        return (
            model.to(device),
            load(self.finetuner),
            load(self.evaluator),
            dummy_input,
            load(self.trainer),
            load(self.optimizer_helper),
            load(self.criterion),
            device,
        )

    def json(self):
        """Return the canonical dict representation of this vessel."""
        canon = self.canonical_copy()
        return asdict(canon)
nni/compression/experiment/experiment.py
0 → 100644
View file @
21539654
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
__future__
import
annotations
import
logging
from
pathlib
import
Path
import
shutil
import
tempfile
from
typing
import
Any
,
Callable
,
List
import
torch
from
torch.nn
import
Module
from
torch.optim
import
Optimizer
from
nni.compression.experiment.config
import
generate_compression_search_space
from
nni.experiment
import
Experiment
from
.config
import
CompressionExperimentConfig
,
CompressionVessel
_logger
=
logging
.
getLogger
(
'nni.experiment'
)
class CompressionExperiment(Experiment):
    """
    Launch an NNI experiment that searches for a good model-compression (pruning) configuration.

    Note: This is an experimental feature, the interface is not stable.

    Parameters
    ----------
    config_or_platform
        A `CompressionExperimentConfig` or the training service name or list of the training service name or None.
    model
        The pytorch model wanted to compress.
    finetuner
        The finetuner handled all finetune logic, use a pytorch module as input.
    evaluator
        Evaluate the pruned model and give a score.
    dummy_input
        It is used by `torch.jit.trace` to trace the model.
    trainer
        A callable function used to train model or just inference. Take model, optimizer, criterion as input.
        Note that the model should only trained or inferenced one epoch in the trainer.

        Example::

            def trainer(model: Module, optimizer: Optimizer, criterion: Callable[[Tensor, Tensor], Tensor]):
                training = model.training
                model.train(mode=True)
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                for batch_idx, (data, target) in enumerate(train_loader):
                    data, target = data.to(device), target.to(device)
                    optimizer.zero_grad()
                    output = model(data)
                    loss = criterion(output, target)
                    loss.backward()
                    # If you don't want to update the model, you can skip `optimizer.step()`, and set train mode False.
                    optimizer.step()
                model.train(mode=training)
    optimizer
        The traced optimizer instance which the optimizer class is wrapped by nni.trace.
        E.g. ``traced_optimizer = nni.trace(torch.nn.Adam)(model.parameters())``.
    criterion
        The criterion function used in trainer. Take model output and target value as input, and return the loss.
    device
        The selected device.

    Raises
    ------
    ValueError
        If any of ``model``, ``finetuner``, ``evaluator`` is missing, or if ``trainer``,
        ``optimizer``, ``criterion`` are only partially provided.
    """

    # keep this interface for now, will change after support lightning
    def __init__(self, config_or_platform: CompressionExperimentConfig | str | List[str] | None,
                 model: Module,
                 finetuner: Callable[[Module], None],
                 evaluator: Callable[[Module], float],
                 dummy_input: Any | None,
                 trainer: Callable[[Module, Optimizer, Callable[[Any, Any], Any]], None] | None,
                 optimizer: Optimizer | None,
                 criterion: Callable[[Any, Any], Any] | None,
                 device: str | torch.device):
        super().__init__(config_or_platform)

        # have some risks if Experiment change its __init__, but work well for current version
        self.config: CompressionExperimentConfig | None = None
        if isinstance(config_or_platform, (str, list)):
            self.config = CompressionExperimentConfig(config_or_platform)
        else:
            self.config = config_or_platform

        # NOTE: `assert` statements are stripped under `python -O`, so validate
        # arguments explicitly and fail fast with an actionable message instead.
        if not all([model, finetuner, evaluator]):
            raise ValueError('`model`, `finetuner` and `evaluator` are all required, none of them can be None.')
        if any([trainer, optimizer, criterion]) and not all([trainer, optimizer, criterion]):
            raise ValueError('`trainer`, `optimizer` and `criterion` must be passed all together or not at all.')

        # Snapshot the initial weights so every trial can start from the same model state.
        self.temp_directory = tempfile.mkdtemp(prefix='nni_compression_{}_'.format(self.id))
        torch.save(model.state_dict(), Path(self.temp_directory, 'model_state_dict.pth'))

        # The vessel bundles everything a trial needs to rebuild the compression context.
        self.vessel = CompressionVessel(model, finetuner, evaluator, dummy_input, trainer, optimizer, criterion, device)

    def start(self, port: int = 8080, debug: bool = False) -> None:
        """Prepare trial command, checkpoint files and search space, then start the experiment."""
        # TODO: python3 is not robust, need support in nni manager
        self.config.trial_command = 'python3 -m nni.compression.experiment.trial_entry'

        # copy files in temp directory to nni_outputs/checkpoint
        # TODO: copy files to code dir is a temporary solution, need nnimanager support upload multi-directory,
        # or package additional files when uploading.
        checkpoint_dir = Path(self.config.trial_code_directory, 'nni_outputs', 'checkpoint')
        shutil.copytree(self.temp_directory, checkpoint_dir, dirs_exist_ok=True)

        # Auto-generate the search space from the compression settings unless the
        # user set one manually (which is discouraged for compression experiments).
        if self.config.search_space or self.config.search_space_file:
            _logger.warning('Manual configuration of search_space is not recommended in compression experiments. %s',
                            'Please make sure you know what will happen.')
        else:
            self.config.search_space = generate_compression_search_space(self.config.compression_setting, self.vessel)

        return super().start(port, debug)
nni/compression/experiment/trial_entry.py
0 → 100644
View file @
21539654
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Entrypoint for trials.
TODO: split this file to several modules
"""
import
math
import
os
from
pathlib
import
Path
import
nni
from
nni.algorithms.compression.v2.pytorch.pruning
import
PruningScheduler
from
nni.algorithms.compression.v2.pytorch.pruning.tools
import
AGPTaskGenerator
from
nni.compression.pytorch.utils
import
count_flops_params
from
.config.utils
import
parse_params
,
parse_basic_pruner
# TODO: move this function to evaluate module
def sigmoid(x: float, theta0: float = -0.5, theta1: float = 10) -> float:
    """Parameterized logistic function ``1 / (1 + exp(-theta1 * (x + theta0)))``.

    Implemented in the numerically stable two-branch form: the naive
    ``math.exp(-theta1 * (x + theta0))`` raises ``OverflowError`` once its
    argument exceeds ~709 (e.g. for strongly negative ``x`` with the default
    ``theta1 = 10``). Here ``exp`` is only ever called on a non-positive
    argument, which underflows gracefully to 0.0 instead.

    Parameters
    ----------
    x
        Input value (e.g. a normalized flops/params/metric ratio).
    theta0
        Horizontal shift; the curve crosses 0.5 at ``x == -theta0``.
    theta1
        Steepness of the transition.
    """
    z = theta1 * (x + theta0)
    if z >= 0:
        return 1 / (1 + math.exp(-z))
    # For z < 0 use the equivalent form exp(z) / (1 + exp(z)); exp(z) <= 1 so it never overflows.
    exp_z = math.exp(z)
    return exp_z / (1 + exp_z)
if __name__ == '__main__':
    # Receive the sampled compression configuration for this trial.
    received_params = nni.get_next_parameter()
    pruner_config, config_list, vessel, original_target, thetas = parse_params(received_params)
    basic_pruner, model, finetuner, evaluator, dummy_input, device = \
        parse_basic_pruner(pruner_config, config_list, vessel)

    # TODO: move following logic to execution engine
    if 'NNI_OUTPUT_DIR' in os.environ:
        log_dir = Path(os.environ['NNI_OUTPUT_DIR'])
    else:
        log_dir = Path('nni_outputs', 'log')

    # Iterative AGP pruning schedule; speedup (graph-level model shrinking)
    # is only possible when a dummy input is available for tracing.
    task_generator = AGPTaskGenerator(total_iteration=3, origin_model=model, origin_config_list=config_list,
                                      skip_first_iteration=True, log_dir=log_dir)
    scheduler = PruningScheduler(pruner=basic_pruner, task_generator=task_generator, finetuner=finetuner,
                                 speedup=dummy_input is not None, dummy_input=dummy_input, evaluator=None)
    scheduler.compress()
    _, model, _, _, _ = scheduler.get_best_result()

    # Measure the pruned model: task metric plus flops/params counts.
    metric = evaluator(model)
    flops, params, _ = count_flops_params(model, dummy_input, verbose=False, mode='full')

    # TODO: more efficient way to calculate or combine these scores
    measured = {'flops': flops, 'params': params, 'metric': metric}
    scores = {key: sigmoid(value / original_target[key], *thetas[key]) for key, value in measured.items()}

    nni.report_final_result({'default': sum(scores.values()), **measured})
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment