Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
75028bd7
Unverified
Commit
75028bd7
authored
Mar 17, 2020
by
SparkSnail
Committed by
GitHub
Mar 17, 2020
Browse files
Merge pull request #235 from microsoft/master
merge master
parents
1d74ae5e
2e42d1d8
Changes
94
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
660 additions
and
332 deletions
+660
-332
src/nni_manager/training_service/dlts/dltsTrialConfig.ts
src/nni_manager/training_service/dlts/dltsTrialConfig.ts
+15
-0
src/nni_manager/training_service/dlts/dltsTrialJobDetail.ts
src/nni_manager/training_service/dlts/dltsTrialJobDetail.ts
+31
-0
src/sdk/pynni/nni/compression/torch/__init__.py
src/sdk/pynni/nni/compression/torch/__init__.py
+1
-1
src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py
...i/nni/compression/torch/activation_rank_filter_pruners.py
+38
-50
src/sdk/pynni/nni/compression/torch/compressor.py
src/sdk/pynni/nni/compression/torch/compressor.py
+117
-91
src/sdk/pynni/nni/compression/torch/pruners.py
src/sdk/pynni/nni/compression/torch/pruners.py
+57
-57
src/sdk/pynni/nni/compression/torch/quantizers.py
src/sdk/pynni/nni/compression/torch/quantizers.py
+31
-27
src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
...pynni/nni/compression/torch/weight_rank_filter_pruners.py
+34
-24
src/sdk/pynni/nni/curvefitting_assessor/curvefitting_assessor.py
.../pynni/nni/curvefitting_assessor/curvefitting_assessor.py
+5
-3
src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
+5
-12
src/sdk/pynni/nni/msg_dispatcher.py
src/sdk/pynni/nni/msg_dispatcher.py
+1
-0
src/sdk/pynni/nni/nas/pytorch/_graph_utils.py
src/sdk/pynni/nni/nas/pytorch/_graph_utils.py
+134
-0
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py
+3
-1
src/sdk/pynni/nni/nas/pytorch/mutator.py
src/sdk/pynni/nni/nas/pytorch/mutator.py
+86
-0
src/sdk/pynni/nni/platform/__init__.py
src/sdk/pynni/nni/platform/__init__.py
+1
-1
src/sdk/pynni/nni/utils.py
src/sdk/pynni/nni/utils.py
+30
-1
src/sdk/pynni/tests/test_compressor.py
src/sdk/pynni/tests/test_compressor.py
+40
-48
src/webui/src/components/Modals/CustomizedTrial.tsx
src/webui/src/components/Modals/CustomizedTrial.tsx
+10
-5
src/webui/src/components/Overview.tsx
src/webui/src/components/Overview.tsx
+4
-4
src/webui/src/components/overview/SuccessTable.tsx
src/webui/src/components/overview/SuccessTable.tsx
+17
-7
No files found.
src/nni_manager/training_service/dlts/dltsTrialConfig.ts
0 → 100644
View file @
75028bd7
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import
{
TrialConfig
}
from
"
training_service/common/trialConfig
"
;
export
class
DLTSTrialConfig
extends
TrialConfig
{
public
constructor
(
command
:
string
,
codeDir
:
string
,
gpuNum
:
number
,
public
readonly
image
:
string
)
{
super
(
command
,
codeDir
,
gpuNum
);
}
}
src/nni_manager/training_service/dlts/dltsTrialJobDetail.ts
0 → 100644
View file @
75028bd7
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.
import
{
TrialJobDetail
,
TrialJobStatus
,
TrialJobApplicationForm
}
from
"
../../common/trainingService
"
;
export
class
DLTSTrialJobDetail
implements
TrialJobDetail
{
public
startTime
?:
number
;
public
endTime
?:
number
;
public
tags
?:
string
[];
public
url
?:
string
;
public
isEarlyStopped
?:
boolean
;
// DLTS staff
public
dltsJobId
?:
string
;
public
dltsPaused
:
boolean
=
false
;
public
constructor
(
public
id
:
string
,
public
status
:
TrialJobStatus
,
public
submitTime
:
number
,
public
workingDirectory
:
string
,
public
form
:
TrialJobApplicationForm
,
// DLTS staff
public
dltsJobName
:
string
,
)
{}
}
src/sdk/pynni/nni/compression/torch/__init__.py
View file @
75028bd7
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
.compressor
import
LayerInfo
,
Compressor
,
Pruner
,
Quantizer
from
.compressor
import
Compressor
,
Pruner
,
Quantizer
from
.pruners
import
*
from
.pruners
import
*
from
.weight_rank_filter_pruners
import
*
from
.weight_rank_filter_pruners
import
*
from
.activation_rank_filter_pruners
import
*
from
.activation_rank_filter_pruners
import
*
...
...
src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py
View file @
75028bd7
...
@@ -16,7 +16,7 @@ class ActivationRankFilterPruner(Pruner):
...
@@ -16,7 +16,7 @@ class ActivationRankFilterPruner(Pruner):
to achieve a preset level of network sparsity.
to achieve a preset level of network sparsity.
"""
"""
def
__init__
(
self
,
model
,
config_list
,
activation
=
'relu'
,
statistics_batch_num
=
1
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
,
activation
=
'relu'
,
statistics_batch_num
=
1
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -25,17 +25,23 @@ class ActivationRankFilterPruner(Pruner):
...
@@ -25,17 +25,23 @@ class ActivationRankFilterPruner(Pruner):
config_list : list
config_list : list
support key for each list item:
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
activation : str
activation : str
Activation function
Activation function
statistics_batch_num : int
statistics_batch_num : int
Num of batches for activation statistics
Num of batches for activation statistics
"""
"""
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
register_buffer
(
"if_calculated"
,
torch
.
tensor
(
0
))
# pylint: disable=not-callable
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
self
.
set_wrappers_attribute
(
"collected_activation"
,
[])
self
.
statistics_batch_num
=
statistics_batch_num
self
.
statistics_batch_num
=
statistics_batch_num
self
.
collected_activation
=
{}
self
.
hooks
=
{}
def
collector
(
module_
,
input_
,
output
):
if
len
(
module_
.
collected_activation
)
<
self
.
statistics_batch_num
:
module_
.
collected_activation
.
append
(
self
.
activation
(
output
.
detach
().
cpu
()))
self
.
add_activation_collector
(
collector
)
assert
activation
in
[
'relu'
,
'relu6'
]
assert
activation
in
[
'relu'
,
'relu6'
]
if
activation
==
'relu'
:
if
activation
==
'relu'
:
self
.
activation
=
torch
.
nn
.
functional
.
relu
self
.
activation
=
torch
.
nn
.
functional
.
relu
...
@@ -44,33 +50,10 @@ class ActivationRankFilterPruner(Pruner):
...
@@ -44,33 +50,10 @@ class ActivationRankFilterPruner(Pruner):
else
:
else
:
self
.
activation
=
None
self
.
activation
=
None
def
compress
(
self
):
"""
Compress the model, register a hook for collecting activations.
"""
if
self
.
modules_wrapper
is
not
None
:
# already compressed
return
self
.
bound_model
else
:
self
.
modules_wrapper
=
[]
modules_to_compress
=
self
.
detect_modules_to_compress
()
for
layer
,
config
in
modules_to_compress
:
wrapper
=
self
.
_wrap_modules
(
layer
,
config
)
self
.
modules_wrapper
.
append
(
wrapper
)
self
.
collected_activation
[
layer
.
name
]
=
[]
def
_hook
(
module_
,
input_
,
output
,
name
=
layer
.
name
):
if
len
(
self
.
collected_activation
[
name
])
<
self
.
statistics_batch_num
:
self
.
collected_activation
[
name
].
append
(
self
.
activation
(
output
.
detach
().
cpu
()))
wrapper
.
module
.
register_forward_hook
(
_hook
)
self
.
_wrap_model
()
return
self
.
bound_model
def
get_mask
(
self
,
base_mask
,
activations
,
num_prune
):
def
get_mask
(
self
,
base_mask
,
activations
,
num_prune
):
raise
NotImplementedError
(
'{} get_mask is not implemented'
.
format
(
self
.
__class__
.
__name__
))
raise
NotImplementedError
(
'{} get_mask is not implemented'
.
format
(
self
.
__class__
.
__name__
))
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
"""
"""
Calculate the mask of given layer.
Calculate the mask of given layer.
Filters with the smallest importance criterion which is calculated from the activation are masked.
Filters with the smallest importance criterion which is calculated from the activation are masked.
...
@@ -88,29 +71,30 @@ class ActivationRankFilterPruner(Pruner):
...
@@ -88,29 +71,30 @@ class ActivationRankFilterPruner(Pruner):
dictionary for storing masks
dictionary for storing masks
"""
"""
weight
=
layer
.
module
.
weight
.
data
weight
=
wrapper
.
module
.
weight
.
data
op_type
=
layer
.
type
op_type
=
wrapper
.
type
config
=
wrapper
.
config
assert
0
<=
config
.
get
(
'sparsity'
)
<
1
,
"sparsity must in the range [0, 1)"
assert
0
<=
config
.
get
(
'sparsity'
)
<
1
,
"sparsity must in the range [0, 1)"
assert
op_type
in
[
'Conv2d'
],
"only support Conv2d"
assert
op_type
in
[
'Conv2d'
],
"only support Conv2d"
assert
op_type
in
config
.
get
(
'op_types'
)
assert
op_type
in
config
.
get
(
'op_types'
)
if_calculated
=
kwargs
[
"if_calculated"
]
if
if_calculated
:
if
wrapper
.
if_calculated
:
return
None
return
None
mask_weight
=
torch
.
ones
(
weight
.
size
()).
type_as
(
weight
).
detach
()
mask_weight
=
torch
.
ones
(
weight
.
size
()).
type_as
(
weight
).
detach
()
if
hasattr
(
lay
er
.
module
,
'bias'
)
and
lay
er
.
module
.
bias
is
not
None
:
if
hasattr
(
wrapp
er
.
module
,
'bias'
)
and
wrapp
er
.
module
.
bias
is
not
None
:
mask_bias
=
torch
.
ones
(
lay
er
.
module
.
bias
.
size
()).
type_as
(
lay
er
.
module
.
bias
).
detach
()
mask_bias
=
torch
.
ones
(
wrapp
er
.
module
.
bias
.
size
()).
type_as
(
wrapp
er
.
module
.
bias
).
detach
()
else
:
else
:
mask_bias
=
None
mask_bias
=
None
mask
=
{
'weight'
:
mask_weight
,
'bias'
:
mask_bias
}
mask
=
{
'weight
_mask
'
:
mask_weight
,
'bias
_mask
'
:
mask_bias
}
try
:
try
:
filters
=
weight
.
size
(
0
)
filters
=
weight
.
size
(
0
)
num_prune
=
int
(
filters
*
config
.
get
(
'sparsity'
))
num_prune
=
int
(
filters
*
config
.
get
(
'sparsity'
))
if
filters
<
2
or
num_prune
<
1
or
len
(
self
.
collected_activation
[
layer
.
name
]
)
<
self
.
statistics_batch_num
:
if
filters
<
2
or
num_prune
<
1
or
len
(
wrapper
.
collected_activation
)
<
self
.
statistics_batch_num
:
return
mask
return
mask
mask
=
self
.
get_mask
(
mask
,
self
.
collected_activation
[
layer
.
name
]
,
num_prune
)
mask
=
self
.
get_mask
(
mask
,
wrapper
.
collected_activation
,
num_prune
)
finally
:
finally
:
if
len
(
self
.
collected_activation
[
layer
.
name
]
)
==
self
.
statistics_batch_num
:
if
len
(
wrapper
.
collected_activation
)
==
self
.
statistics_batch_num
:
if_calculated
.
copy_
(
torch
.
tensor
(
1
))
# pylint: disable=not-callabl
e
wrapper
.
if_calculated
=
Tru
e
return
mask
return
mask
...
@@ -123,7 +107,7 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
...
@@ -123,7 +107,7 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
https://arxiv.org/abs/1607.03250
https://arxiv.org/abs/1607.03250
"""
"""
def
__init__
(
self
,
model
,
config_list
,
activation
=
'relu'
,
statistics_batch_num
=
1
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
,
activation
=
'relu'
,
statistics_batch_num
=
1
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -132,12 +116,14 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
...
@@ -132,12 +116,14 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
config_list : list
config_list : list
support key for each list item:
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
activation : str
activation : str
Activation function
Activation function
statistics_batch_num : int
statistics_batch_num : int
Num of batches for activation statistics
Num of batches for activation statistics
"""
"""
super
().
__init__
(
model
,
config_list
,
activation
,
statistics_batch_num
)
super
().
__init__
(
model
,
config_list
,
optimizer
,
activation
,
statistics_batch_num
)
def
get_mask
(
self
,
base_mask
,
activations
,
num_prune
):
def
get_mask
(
self
,
base_mask
,
activations
,
num_prune
):
"""
"""
...
@@ -161,9 +147,9 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
...
@@ -161,9 +147,9 @@ class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
apoz
=
self
.
_calc_apoz
(
activations
)
apoz
=
self
.
_calc_apoz
(
activations
)
prune_indices
=
torch
.
argsort
(
apoz
,
descending
=
True
)[:
num_prune
]
prune_indices
=
torch
.
argsort
(
apoz
,
descending
=
True
)[:
num_prune
]
for
idx
in
prune_indices
:
for
idx
in
prune_indices
:
base_mask
[
'weight'
][
idx
]
=
0.
base_mask
[
'weight
_mask
'
][
idx
]
=
0.
if
base_mask
[
'bias'
]
is
not
None
:
if
base_mask
[
'bias
_mask
'
]
is
not
None
:
base_mask
[
'bias'
][
idx
]
=
0.
base_mask
[
'bias
_mask
'
][
idx
]
=
0.
return
base_mask
return
base_mask
def
_calc_apoz
(
self
,
activations
):
def
_calc_apoz
(
self
,
activations
):
...
@@ -195,7 +181,7 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
...
@@ -195,7 +181,7 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
https://arxiv.org/abs/1611.06440
https://arxiv.org/abs/1611.06440
"""
"""
def
__init__
(
self
,
model
,
config_list
,
activation
=
'relu'
,
statistics_batch_num
=
1
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
,
activation
=
'relu'
,
statistics_batch_num
=
1
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -204,12 +190,14 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
...
@@ -204,12 +190,14 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
config_list : list
config_list : list
support key for each list item:
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
activation : str
activation : str
Activation function
Activation function
statistics_batch_num : int
statistics_batch_num : int
Num of batches for activation statistics
Num of batches for activation statistics
"""
"""
super
().
__init__
(
model
,
config_list
,
activation
,
statistics_batch_num
)
super
().
__init__
(
model
,
config_list
,
optimizer
,
activation
,
statistics_batch_num
)
def
get_mask
(
self
,
base_mask
,
activations
,
num_prune
):
def
get_mask
(
self
,
base_mask
,
activations
,
num_prune
):
"""
"""
...
@@ -233,9 +221,9 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
...
@@ -233,9 +221,9 @@ class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
mean_activation
=
self
.
_cal_mean_activation
(
activations
)
mean_activation
=
self
.
_cal_mean_activation
(
activations
)
prune_indices
=
torch
.
argsort
(
mean_activation
)[:
num_prune
]
prune_indices
=
torch
.
argsort
(
mean_activation
)[:
num_prune
]
for
idx
in
prune_indices
:
for
idx
in
prune_indices
:
base_mask
[
'weight'
][
idx
]
=
0.
base_mask
[
'weight
_mask
'
][
idx
]
=
0.
if
base_mask
[
'bias'
]
is
not
None
:
if
base_mask
[
'bias
_mask
'
]
is
not
None
:
base_mask
[
'bias'
][
idx
]
=
0.
base_mask
[
'bias
_mask
'
][
idx
]
=
0.
return
base_mask
return
base_mask
def
_cal_mean_activation
(
self
,
activations
):
def
_cal_mean_activation
(
self
,
activations
):
...
...
src/sdk/pynni/nni/compression/torch/compressor.py
View file @
75028bd7
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
import
types
import
logging
import
logging
import
torch
import
torch
from
.
import
default_layers
from
.
import
default_layers
...
@@ -20,12 +21,13 @@ def _setattr(model, name, module):
...
@@ -20,12 +21,13 @@ def _setattr(model, name, module):
model
=
getattr
(
model
,
name
)
model
=
getattr
(
model
,
name
)
setattr
(
model
,
name_list
[
-
1
],
module
)
setattr
(
model
,
name_list
[
-
1
],
module
)
class
Compressor
:
class
Compressor
:
"""
"""
Abstract base PyTorch compressor
Abstract base PyTorch compressor
"""
"""
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
"""
"""
Record necessary info in class members
Record necessary info in class members
...
@@ -35,15 +37,27 @@ class Compressor:
...
@@ -35,15 +37,27 @@ class Compressor:
the model user wants to compress
the model user wants to compress
config_list : list
config_list : list
the configurations that users specify for compression
the configurations that users specify for compression
optimizer: pytorch optimizer
optimizer used to train the model
"""
"""
self
.
bound_model
=
model
self
.
bound_model
=
model
self
.
config_list
=
config_list
self
.
config_list
=
config_list
self
.
optimizer
=
optimizer
self
.
modules_to_compress
=
None
self
.
modules_to_compress
=
None
self
.
modules_wrapper
=
None
self
.
modules_wrapper
=
[]
self
.
buffers
=
{}
self
.
is_wrapped
=
False
self
.
is_wrapped
=
False
def
detect_modules_to_compress
(
self
):
self
.
_fwd_hook_handles
=
{}
self
.
_fwd_hook_id
=
0
for
layer
,
config
in
self
.
_detect_modules_to_compress
():
wrapper
=
self
.
_wrap_modules
(
layer
,
config
)
self
.
modules_wrapper
.
append
(
wrapper
)
self
.
_wrap_model
()
def
_detect_modules_to_compress
(
self
):
"""
"""
detect all modules should be compressed, and save the result in `self.modules_to_compress`.
detect all modules should be compressed, and save the result in `self.modules_to_compress`.
The model will be instrumented and user should never edit it after calling this method.
The model will be instrumented and user should never edit it after calling this method.
...
@@ -87,26 +101,26 @@ class Compressor:
...
@@ -87,26 +101,26 @@ class Compressor:
torch.nn.Module
torch.nn.Module
model with specified modules compressed.
model with specified modules compressed.
"""
"""
if
self
.
modules_wrapper
is
not
None
:
# already compressed
return
self
.
bound_model
else
:
self
.
modules_wrapper
=
[]
modules_to_compress
=
self
.
detect_modules_to_compress
()
for
layer
,
config
in
modules_to_compress
:
wrapper
=
self
.
_wrap_modules
(
layer
,
config
)
self
.
modules_wrapper
.
append
(
wrapper
)
self
.
_wrap_model
()
return
self
.
bound_model
return
self
.
bound_model
def
register_buffer
(
self
,
name
,
value
):
def
set_wrappers_attribute
(
self
,
name
,
value
):
"""
"""
To register buffers used in wrapped module's forward method.
To register attributes used in wrapped module's forward method.
If the type of the value is Torch.tensor, then this value is registered as a buffer in wrapper,
which will be saved by model.state_dict. Otherwise, this value is just a regular variable in wrapper.
Parameters
----------
name : str
name of the variable
value: any
value of the variable
"""
"""
self
.
buffers
[
name
]
=
value
for
wrapper
in
self
.
get_modules_wrapper
():
if
isinstance
(
value
,
torch
.
Tensor
):
wrapper
.
register_buffer
(
name
,
value
.
clone
())
else
:
setattr
(
wrapper
,
name
,
value
)
def
get_modules_to_compress
(
self
):
def
get_modules_to_compress
(
self
):
"""
"""
...
@@ -180,11 +194,7 @@ class Compressor:
...
@@ -180,11 +194,7 @@ class Compressor:
epoch : num
epoch : num
the current epoch number
the current epoch number
"""
"""
pass
def
step
(
self
):
"""
If user want to update model every step, user can override this method
"""
def
_wrap_modules
(
self
,
layer
,
config
):
def
_wrap_modules
(
self
,
layer
,
config
):
"""
"""
...
@@ -200,6 +210,34 @@ class Compressor:
...
@@ -200,6 +210,34 @@ class Compressor:
raise
NotImplementedError
()
raise
NotImplementedError
()
def
add_activation_collector
(
self
,
collector
):
self
.
_fwd_hook_id
+=
1
self
.
_fwd_hook_handles
[
self
.
_fwd_hook_id
]
=
[]
for
wrapper
in
self
.
get_modules_wrapper
():
handle
=
wrapper
.
register_forward_hook
(
collector
)
self
.
_fwd_hook_handles
[
self
.
_fwd_hook_id
].
append
(
handle
)
return
self
.
_fwd_hook_id
def
remove_activation_collector
(
self
,
fwd_hook_id
):
if
fwd_hook_id
not
in
self
.
_fwd_hook_handles
:
raise
ValueError
(
"%s is not a valid collector id"
%
str
(
fwd_hook_id
))
for
handle
in
self
.
_fwd_hook_handles
[
fwd_hook_id
]:
handle
.
remove
()
del
self
.
_fwd_hook_handles
[
fwd_hook_id
]
def
patch_optimizer
(
self
,
*
tasks
):
def
patch_step
(
old_step
):
def
new_step
(
_
,
*
args
,
**
kwargs
):
# call origin optimizer step method
output
=
old_step
(
*
args
,
**
kwargs
)
# calculate mask
for
task
in
tasks
:
task
()
return
output
return
new_step
if
self
.
optimizer
is
not
None
:
self
.
optimizer
.
step
=
types
.
MethodType
(
patch_step
(
self
.
optimizer
.
step
),
self
.
optimizer
)
class
PrunerModuleWrapper
(
torch
.
nn
.
Module
):
class
PrunerModuleWrapper
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
module
,
module_name
,
module_type
,
config
,
pruner
):
def
__init__
(
self
,
module
,
module_name
,
module_type
,
config
,
pruner
):
"""
"""
...
@@ -226,7 +264,6 @@ class PrunerModuleWrapper(torch.nn.Module):
...
@@ -226,7 +264,6 @@ class PrunerModuleWrapper(torch.nn.Module):
# config and pruner
# config and pruner
self
.
config
=
config
self
.
config
=
config
self
.
pruner
=
pruner
self
.
pruner
=
pruner
self
.
registered_buffers
=
[]
# register buffer for mask
# register buffer for mask
self
.
register_buffer
(
"weight_mask"
,
torch
.
ones
(
self
.
module
.
weight
.
shape
))
self
.
register_buffer
(
"weight_mask"
,
torch
.
ones
(
self
.
module
.
weight
.
shape
))
...
@@ -234,29 +271,11 @@ class PrunerModuleWrapper(torch.nn.Module):
...
@@ -234,29 +271,11 @@ class PrunerModuleWrapper(torch.nn.Module):
self
.
register_buffer
(
"bias_mask"
,
torch
.
ones
(
self
.
module
.
bias
.
shape
))
self
.
register_buffer
(
"bias_mask"
,
torch
.
ones
(
self
.
module
.
bias
.
shape
))
else
:
else
:
self
.
register_buffer
(
"bias_mask"
,
None
)
self
.
register_buffer
(
"bias_mask"
,
None
)
self
.
registered_buffers
.
append
(
'weight_mask'
)
self
.
registered_buffers
.
append
(
'bias_mask'
)
# register user specified buffer
for
name
in
self
.
pruner
.
buffers
:
self
.
register_buffer
(
name
,
self
.
pruner
.
buffers
[
name
].
clone
())
self
.
registered_buffers
.
append
(
name
)
def
get_registered_buffers
(
self
):
buffers
=
{}
for
name
in
self
.
registered_buffers
:
buffers
[
name
]
=
getattr
(
self
,
name
)
return
buffers
def
forward
(
self
,
*
inputs
):
def
forward
(
self
,
*
inputs
):
mask
=
self
.
pruner
.
calc_mask
(
LayerInfo
(
self
.
name
,
self
.
module
),
self
.
config
,
**
self
.
get_registered_buffers
())
# apply mask to weight, bias
if
mask
is
not
None
:
self
.
weight_mask
.
copy_
(
mask
[
'weight'
])
# apply mask to weight
self
.
module
.
weight
.
data
=
self
.
module
.
weight
.
data
.
mul_
(
self
.
weight_mask
)
self
.
module
.
weight
.
data
=
self
.
module
.
weight
.
data
.
mul_
(
self
.
weight_mask
)
# apply mask to bias
if
hasattr
(
self
.
module
,
'bias'
)
and
self
.
module
.
bias
is
not
None
:
if
hasattr
(
self
.
module
,
'bias'
)
and
self
.
module
.
bias
is
not
None
:
if
mask
is
not
None
and
'bias'
in
mask
:
self
.
bias_mask
.
copy_
(
mask
[
'bias'
])
self
.
module
.
bias
.
data
=
self
.
module
.
bias
.
data
.
mul_
(
self
.
bias_mask
)
self
.
module
.
bias
.
data
=
self
.
module
.
bias
.
data
.
mul_
(
self
.
bias_mask
)
return
self
.
module
(
*
inputs
)
return
self
.
module
(
*
inputs
)
...
@@ -272,10 +291,24 @@ class Pruner(Compressor):
...
@@ -272,10 +291,24 @@ class Pruner(Compressor):
"""
"""
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
if
optimizer
is
not
None
:
self
.
patch_optimizer
(
self
.
update_mask
)
def
compress
(
self
):
self
.
update_mask
()
return
self
.
bound_model
def
update_mask
(
self
):
for
wrapper
in
self
.
get_modules_wrapper
():
masks
=
self
.
calc_mask
(
wrapper
)
if
masks
is
not
None
:
for
k
in
masks
:
assert
hasattr
(
wrapper
,
k
),
"there is no attribute '%s' in wrapper"
%
k
setattr
(
wrapper
,
k
,
masks
[
k
])
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
"""
"""
Pruners should overload this method to provide mask for weight tensors.
Pruners should overload this method to provide mask for weight tensors.
The mask must have the same shape and type comparing to the weight.
The mask must have the same shape and type comparing to the weight.
...
@@ -284,10 +317,8 @@ class Pruner(Compressor):
...
@@ -284,10 +317,8 @@ class Pruner(Compressor):
Parameters
Parameters
----------
----------
layer : LayerInfo
wrapper : Module
calculate mask for `layer`'s weight
calculate mask for `wrapper.module`'s weight
config : dict
the configuration for generating the mask
"""
"""
raise
NotImplementedError
(
"Pruners must overload calc_mask()"
)
raise
NotImplementedError
(
"Pruners must overload calc_mask()"
)
...
@@ -327,8 +358,6 @@ class Pruner(Compressor):
...
@@ -327,8 +358,6 @@ class Pruner(Compressor):
device of the model, used to place the dummy input tensor for exporting onnx file.
device of the model, used to place the dummy input tensor for exporting onnx file.
the tensor is placed on cpu if ```device``` is None
the tensor is placed on cpu if ```device``` is None
"""
"""
# if self.detect_modules_to_compress() and not self.mask_dict:
# _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
assert
model_path
is
not
None
,
'model_path must be specified'
assert
model_path
is
not
None
,
'model_path must be specified'
mask_dict
=
{}
mask_dict
=
{}
self
.
_unwrap_model
()
# used for generating correct state_dict name without wrapper state
self
.
_unwrap_model
()
# used for generating correct state_dict name without wrapper state
...
@@ -404,7 +433,6 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -404,7 +433,6 @@ class QuantizerModuleWrapper(torch.nn.Module):
# config and pruner
# config and pruner
self
.
config
=
config
self
.
config
=
config
self
.
quantizer
=
quantizer
self
.
quantizer
=
quantizer
self
.
registered_buffers
=
[]
# register buffer and parameter
# register buffer and parameter
# old_weight is used to store origin weight and weight is used to store quantized weight
# old_weight is used to store origin weight and weight is used to store quantized weight
...
@@ -418,35 +446,18 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -418,35 +446,18 @@ class QuantizerModuleWrapper(torch.nn.Module):
delattr
(
self
.
module
,
'weight'
)
delattr
(
self
.
module
,
'weight'
)
self
.
module
.
register_buffer
(
'weight'
,
self
.
module
.
old_weight
)
self
.
module
.
register_buffer
(
'weight'
,
self
.
module
.
old_weight
)
# register user specified buffer
for
name
in
self
.
quantizer
.
buffers
:
self
.
register_buffer
(
name
,
self
.
quantizer
.
buffers
[
name
].
clone
())
self
.
registered_buffers
.
append
(
name
)
def
get_registered_buffers
(
self
):
buffers
=
{}
for
name
in
self
.
registered_buffers
:
buffers
[
name
]
=
getattr
(
self
,
name
)
return
buffers
def
forward
(
self
,
*
inputs
):
def
forward
(
self
,
*
inputs
):
if
'input'
in
self
.
config
[
'quant_types'
]:
if
'input'
in
self
.
config
[
'quant_types'
]:
inputs
=
self
.
quantizer
.
quant_grad
.
apply
(
inputs
=
self
.
quantizer
.
quant_grad
.
apply
(
inputs
,
inputs
,
QuantType
.
QUANT_INPUT
,
QuantType
.
QUANT_INPUT
,
self
.
quantizer
.
quantize_input
,
self
)
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
get_registered_buffers
())
if
'weight'
in
self
.
config
[
'quant_types'
]
and
_check_weight
(
self
.
module
):
if
'weight'
in
self
.
config
[
'quant_types'
]
and
_check_weight
(
self
.
module
):
new_weight
=
self
.
quantizer
.
quant_grad
.
apply
(
new_weight
=
self
.
quantizer
.
quant_grad
.
apply
(
self
.
module
.
old_weight
,
self
.
module
.
old_weight
,
QuantType
.
QUANT_WEIGHT
,
QuantType
.
QUANT_WEIGHT
,
self
.
quantizer
.
quantize_weight
,
self
)
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
get_registered_buffers
())
self
.
module
.
weight
=
new_weight
self
.
module
.
weight
=
new_weight
result
=
self
.
module
(
*
inputs
)
result
=
self
.
module
(
*
inputs
)
else
:
else
:
...
@@ -456,10 +467,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
...
@@ -456,10 +467,7 @@ class QuantizerModuleWrapper(torch.nn.Module):
result
=
self
.
quantizer
.
quant_grad
.
apply
(
result
=
self
.
quantizer
.
quant_grad
.
apply
(
result
,
result
,
QuantType
.
QUANT_OUTPUT
,
QuantType
.
QUANT_OUTPUT
,
self
.
quantizer
.
quantize_output
,
self
)
self
.
config
,
LayerInfo
(
self
.
name
,
self
.
module
),
**
self
.
get_registered_buffers
())
return
result
return
result
class
Quantizer
(
Compressor
):
class
Quantizer
(
Compressor
):
...
@@ -467,11 +475,18 @@ class Quantizer(Compressor):
...
@@ -467,11 +475,18 @@ class Quantizer(Compressor):
Base quantizer for pytorch quantizer
Base quantizer for pytorch quantizer
"""
"""
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
quant_grad
=
QuantGrad
self
.
quant_grad
=
QuantGrad
if
self
.
optimizer
is
not
None
:
self
.
patch_optimizer
(
self
.
step_with_optimizer
)
for
wrapper
in
self
.
get_modules_wrapper
():
if
'weight'
in
wrapper
.
config
[
'quant_types'
]:
# old_weight is registered to keep track of weight before quantization
# and it is trainable, therefore, it should be added to optimizer.
self
.
optimizer
.
add_param_group
({
"params"
:
wrapper
.
module
.
old_weight
})
def
quantize_weight
(
self
,
weight
,
config
,
op
,
op_ty
pe
,
op_name
):
def
quantize_weight
(
self
,
weight
,
wrap
pe
r
,
**
kwargs
):
"""
"""
quantize should overload this method to quantize weight.
quantize should overload this method to quantize weight.
This method is effectively hooked to :meth:`forward` of the model.
This method is effectively hooked to :meth:`forward` of the model.
...
@@ -479,12 +494,12 @@ class Quantizer(Compressor):
...
@@ -479,12 +494,12 @@ class Quantizer(Compressor):
----------
----------
weight : Tensor
weight : Tensor
weight that needs to be quantized
weight that needs to be quantized
config : dict
wrapper : QuantizerModuleWrapper
the
configuration for weight quantization
the
wrapper for origin module
"""
"""
raise
NotImplementedError
(
'Quantizer must overload quantize_weight()'
)
raise
NotImplementedError
(
'Quantizer must overload quantize_weight()'
)
def
quantize_output
(
self
,
output
,
config
,
op
,
op_ty
pe
,
op_name
):
def
quantize_output
(
self
,
output
,
wrap
pe
r
,
**
kwargs
):
"""
"""
quantize should overload this method to quantize output.
quantize should overload this method to quantize output.
This method is effectively hooked to :meth:`forward` of the model.
This method is effectively hooked to :meth:`forward` of the model.
...
@@ -492,12 +507,12 @@ class Quantizer(Compressor):
...
@@ -492,12 +507,12 @@ class Quantizer(Compressor):
----------
----------
output : Tensor
output : Tensor
output that needs to be quantized
output that needs to be quantized
config : dict
wrapper : QuantizerModuleWrapper
the
configuration for output quantization
the
wrapper for origin module
"""
"""
raise
NotImplementedError
(
'Quantizer must overload quantize_output()'
)
raise
NotImplementedError
(
'Quantizer must overload quantize_output()'
)
def
quantize_input
(
self
,
*
inputs
,
config
,
op
,
op_ty
pe
,
op_name
):
def
quantize_input
(
self
,
*
inputs
,
wrap
pe
r
,
**
kwargs
):
"""
"""
quantize should overload this method to quantize input.
quantize should overload this method to quantize input.
This method is effectively hooked to :meth:`forward` of the model.
This method is effectively hooked to :meth:`forward` of the model.
...
@@ -505,8 +520,8 @@ class Quantizer(Compressor):
...
@@ -505,8 +520,8 @@ class Quantizer(Compressor):
----------
----------
inputs : Tensor
inputs : Tensor
inputs that needs to be quantized
inputs that needs to be quantized
config : dict
wrapper : QuantizerModuleWrapper
the
configuration for inputs quantization
the
wrapper for origin module
"""
"""
raise
NotImplementedError
(
'Quantizer must overload quantize_input()'
)
raise
NotImplementedError
(
'Quantizer must overload quantize_input()'
)
...
@@ -532,6 +547,9 @@ class Quantizer(Compressor):
...
@@ -532,6 +547,9 @@ class Quantizer(Compressor):
return
QuantizerModuleWrapper
(
layer
.
module
,
layer
.
name
,
layer
.
type
,
config
,
self
)
return
QuantizerModuleWrapper
(
layer
.
module
,
layer
.
name
,
layer
.
type
,
config
,
self
)
def
step_with_optimizer
(
self
):
pass
class
QuantType
:
class
QuantType
:
"""
"""
Enum class for quantization type.
Enum class for quantization type.
...
@@ -540,6 +558,7 @@ class QuantType:
...
@@ -540,6 +558,7 @@ class QuantType:
QUANT_WEIGHT
=
1
QUANT_WEIGHT
=
1
QUANT_OUTPUT
=
2
QUANT_OUTPUT
=
2
class
QuantGrad
(
torch
.
autograd
.
Function
):
class
QuantGrad
(
torch
.
autograd
.
Function
):
"""
"""
Base class for overriding backward function of quantization operation.
Base class for overriding backward function of quantization operation.
...
@@ -566,15 +585,22 @@ class QuantGrad(torch.autograd.Function):
...
@@ -566,15 +585,22 @@ class QuantGrad(torch.autograd.Function):
return
grad_output
return
grad_output
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
tensor
,
quant_type
,
quant_func
,
config
,
lay
er
,
**
kwargs
):
def
forward
(
ctx
,
tensor
,
quant_type
,
wrapp
er
,
**
kwargs
):
ctx
.
save_for_backward
(
tensor
,
torch
.
Tensor
([
quant_type
]))
ctx
.
save_for_backward
(
tensor
,
torch
.
Tensor
([
quant_type
]))
return
quant_func
(
tensor
,
config
,
op
=
layer
.
module
,
op_type
=
layer
.
type
,
op_name
=
layer
.
name
,
**
kwargs
)
if
quant_type
==
QuantType
.
QUANT_INPUT
:
return
wrapper
.
quantizer
.
quantize_input
(
tensor
,
wrapper
,
**
kwargs
)
elif
quant_type
==
QuantType
.
QUANT_WEIGHT
:
return
wrapper
.
quantizer
.
quantize_weight
(
tensor
,
wrapper
,
**
kwargs
)
elif
quant_type
==
QuantType
.
QUANT_OUTPUT
:
return
wrapper
.
quantizer
.
quantize_output
(
tensor
,
wrapper
,
**
kwargs
)
else
:
raise
ValueError
(
"unrecognized QuantType."
)
@
classmethod
@
classmethod
def
backward
(
cls
,
ctx
,
grad_output
):
def
backward
(
cls
,
ctx
,
grad_output
):
tensor
,
quant_type
=
ctx
.
saved_variables
tensor
,
quant_type
=
ctx
.
saved_variables
output
=
cls
.
quant_backward
(
tensor
,
grad_output
,
quant_type
)
output
=
cls
.
quant_backward
(
tensor
,
grad_output
,
quant_type
)
return
output
,
None
,
None
,
None
,
None
,
None
return
output
,
None
,
None
,
None
def
_check_weight
(
module
):
def
_check_weight
(
module
):
try
:
try
:
...
...
src/sdk/pynni/nni/compression/torch/pruners.py
View file @
75028bd7
...
@@ -16,7 +16,7 @@ class LevelPruner(Pruner):
...
@@ -16,7 +16,7 @@ class LevelPruner(Pruner):
Prune to an exact pruning level specification
Prune to an exact pruning level specification
"""
"""
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -24,38 +24,39 @@ class LevelPruner(Pruner):
...
@@ -24,38 +24,39 @@ class LevelPruner(Pruner):
Model to be pruned
Model to be pruned
config_list : list
config_list : list
List on pruning configs
List on pruning configs
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
"""
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
register_buffer
(
"if_calculated"
,
torch
.
tensor
(
0
))
# pylint: disable=not-callable
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
"""
"""
Calculate the mask of given layer
Calculate the mask of given layer
Parameters
Parameters
----------
----------
layer : LayerInfo
wrapper : Module
the layer to instrument the compression operation
the module to instrument the compression operation
config : dict
layer's pruning config
Returns
Returns
-------
-------
dict
dict
dictionary for storing masks
dictionary for storing masks
"""
"""
weight
=
layer
.
module
.
weight
.
data
config
=
wrapper
.
config
if_calculated
=
kwargs
[
"if_calculated"
]
weight
=
wrapper
.
module
.
weight
.
data
if
not
if_calculated
:
if
not
wrapper
.
if_calculated
:
w_abs
=
weight
.
abs
()
w_abs
=
weight
.
abs
()
k
=
int
(
weight
.
numel
()
*
config
[
'sparsity'
])
k
=
int
(
weight
.
numel
()
*
config
[
'sparsity'
])
if
k
==
0
:
if
k
==
0
:
return
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
)
return
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
)
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
)[
0
].
max
()
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
)[
0
].
max
()
mask_weight
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
mask_weight
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
mask
=
{
'weight'
:
mask_weight
}
mask
=
{
'weight
_mask
'
:
mask_weight
}
if_calculated
.
copy_
(
torch
.
tensor
(
1
))
# pylint: disable=not-callabl
e
wrapper
.
if_calculated
=
Tru
e
return
mask
return
mask
else
:
else
:
return
None
return
None
...
@@ -71,7 +72,7 @@ class AGP_Pruner(Pruner):
...
@@ -71,7 +72,7 @@ class AGP_Pruner(Pruner):
https://arxiv.org/pdf/1710.01878.pdf
https://arxiv.org/pdf/1710.01878.pdf
"""
"""
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -79,50 +80,51 @@ class AGP_Pruner(Pruner):
...
@@ -79,50 +80,51 @@ class AGP_Pruner(Pruner):
Model to be pruned
Model to be pruned
config_list : list
config_list : list
List on pruning configs
List on pruning configs
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
"""
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
assert
isinstance
(
optimizer
,
torch
.
optim
.
Optimizer
),
"AGP pruner is an iterative pruner, please pass optimizer of the model to it"
self
.
now_epoch
=
0
self
.
now_epoch
=
0
self
.
register_buffer
(
"if_calculated"
,
torch
.
tensor
(
0
))
# pylint: disable=not-callable
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
"""
"""
Calculate the mask of given layer.
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
Parameters
----------
----------
layer : LayerInfo
wrapper : Module
the layer to instrument the compression operation
the layer to instrument the compression operation
config : dict
layer's pruning config
kwargs: dict
buffers registered in __init__ function
Returns
Returns
-------
-------
dict
dict
dictionary for storing masks
dictionary for storing masks
"""
"""
weight
=
layer
.
module
.
weight
.
data
config
=
wrapper
.
config
weight
=
wrapper
.
module
.
weight
.
data
start_epoch
=
config
.
get
(
'start_epoch'
,
0
)
start_epoch
=
config
.
get
(
'start_epoch'
,
0
)
freq
=
config
.
get
(
'frequency'
,
1
)
freq
=
config
.
get
(
'frequency'
,
1
)
if_calculated
=
kwargs
[
"if_calculated"
]
if
wrapper
.
if_calculated
:
if
if_calculated
:
return
None
return
None
if
not
(
self
.
now_epoch
>=
start_epoch
and
(
self
.
now_epoch
-
start_epoch
)
%
freq
==
0
):
if
not
(
self
.
now_epoch
>=
start_epoch
and
(
self
.
now_epoch
-
start_epoch
)
%
freq
==
0
):
return
None
return
None
mask
=
{
'weight
'
:
kwargs
[
'weight_mask'
]
if
'weight_mask'
in
kwargs
else
torch
.
ones
(
weight
.
shape
).
type_as
(
weight
)
}
mask
=
{
'weight
_mask'
:
wrapper
.
weight_mask
}
target_sparsity
=
self
.
compute_target_sparsity
(
config
)
target_sparsity
=
self
.
compute_target_sparsity
(
config
)
k
=
int
(
weight
.
numel
()
*
target_sparsity
)
k
=
int
(
weight
.
numel
()
*
target_sparsity
)
if
k
==
0
or
target_sparsity
>=
1
or
target_sparsity
<=
0
:
if
k
==
0
or
target_sparsity
>=
1
or
target_sparsity
<=
0
:
return
mask
return
mask
# if we want to generate new mask, we should update weigth first
# if we want to generate new mask, we should update weigth first
w_abs
=
weight
.
abs
()
*
mask
[
'weight'
]
w_abs
=
weight
.
abs
()
*
mask
[
'weight
_mask
'
]
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
)[
0
].
max
()
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
)[
0
].
max
()
new_mask
=
{
'weight'
:
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)}
new_mask
=
{
'weight
_mask
'
:
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)}
if_calculated
.
copy_
(
torch
.
tensor
(
1
))
# pylint: disable=not-callabl
e
wrapper
.
if_calculated
=
Tru
e
return
new_mask
return
new_mask
...
@@ -180,62 +182,64 @@ class SlimPruner(Pruner):
...
@@ -180,62 +182,64 @@ class SlimPruner(Pruner):
https://arxiv.org/pdf/1708.06519.pdf
https://arxiv.org/pdf/1708.06519.pdf
"""
"""
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
"""
"""
Parameters
Parameters
----------
----------
model : torch.nn.module
Model to be pruned
config_list : list
config_list : list
support key for each list item:
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
"""
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
weight_list
=
[]
weight_list
=
[]
if
len
(
config_list
)
>
1
:
if
len
(
config_list
)
>
1
:
logger
.
warning
(
'Slim pruner only supports 1 configuration'
)
logger
.
warning
(
'Slim pruner only supports 1 configuration'
)
config
=
config_list
[
0
]
config
=
config_list
[
0
]
for
(
layer
,
config
)
in
self
.
detec
t_modules_to_compress
():
for
(
layer
,
config
)
in
self
.
ge
t_modules_to_compress
():
assert
layer
.
type
==
'BatchNorm2d'
,
'SlimPruner only supports 2d batch normalization layer pruning'
assert
layer
.
type
==
'BatchNorm2d'
,
'SlimPruner only supports 2d batch normalization layer pruning'
weight_list
.
append
(
layer
.
module
.
weight
.
data
.
abs
().
clone
())
weight_list
.
append
(
layer
.
module
.
weight
.
data
.
abs
().
clone
())
all_bn_weights
=
torch
.
cat
(
weight_list
)
all_bn_weights
=
torch
.
cat
(
weight_list
)
k
=
int
(
all_bn_weights
.
shape
[
0
]
*
config
[
'sparsity'
])
k
=
int
(
all_bn_weights
.
shape
[
0
]
*
config
[
'sparsity'
])
self
.
global_threshold
=
torch
.
topk
(
all_bn_weights
.
view
(
-
1
),
k
,
largest
=
False
)[
0
].
max
()
self
.
global_threshold
=
torch
.
topk
(
all_bn_weights
.
view
(
-
1
),
k
,
largest
=
False
)[
0
].
max
()
self
.
register_buffer
(
"if_calculated"
,
torch
.
tensor
(
0
))
# pylint: disable=not-callable
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
"""
"""
Calculate the mask of given layer.
Calculate the mask of given layer.
Scale factors with the smallest absolute value in the BN layer are masked.
Scale factors with the smallest absolute value in the BN layer are masked.
Parameters
Parameters
----------
----------
layer : LayerInfo
wrapper : Module
the layer to instrument the compression operation
the layer to instrument the compression operation
config : dict
layer's pruning config
kwargs: dict
buffers registered in __init__ function
Returns
Returns
-------
-------
dict
dict
dictionary for storing masks
dictionary for storing masks
"""
"""
weight
=
layer
.
module
.
weight
.
data
config
=
wrapper
.
config
op_type
=
layer
.
type
weight
=
wrapper
.
module
.
weight
.
data
if_calculated
=
kwargs
[
"if_calculated"
]
op_type
=
wrapper
.
type
assert
op_type
==
'BatchNorm2d'
,
'SlimPruner only supports 2d batch normalization layer pruning'
assert
op_type
==
'BatchNorm2d'
,
'SlimPruner only supports 2d batch normalization layer pruning'
if
if_calculated
:
if
wrapper
.
if_calculated
:
return
None
return
None
base_mask
=
torch
.
ones
(
weight
.
size
()).
type_as
(
weight
).
detach
()
base_mask
=
torch
.
ones
(
weight
.
size
()).
type_as
(
weight
).
detach
()
mask
=
{
'weight'
:
base_mask
.
detach
(),
'bias'
:
base_mask
.
clone
().
detach
()}
mask
=
{
'weight
_mask
'
:
base_mask
.
detach
(),
'bias
_mask
'
:
base_mask
.
clone
().
detach
()}
filters
=
weight
.
size
(
0
)
filters
=
weight
.
size
(
0
)
num_prune
=
int
(
filters
*
config
.
get
(
'sparsity'
))
num_prune
=
int
(
filters
*
config
.
get
(
'sparsity'
))
if
filters
>=
2
and
num_prune
>=
1
:
if
filters
>=
2
and
num_prune
>=
1
:
w_abs
=
weight
.
abs
()
w_abs
=
weight
.
abs
()
mask_weight
=
torch
.
gt
(
w_abs
,
self
.
global_threshold
).
type_as
(
weight
)
mask_weight
=
torch
.
gt
(
w_abs
,
self
.
global_threshold
).
type_as
(
weight
)
mask_bias
=
mask_weight
.
clone
()
mask_bias
=
mask_weight
.
clone
()
mask
=
{
'weight'
:
mask_weight
.
detach
(),
'bias'
:
mask_bias
.
detach
()}
mask
=
{
'weight
_mask
'
:
mask_weight
.
detach
(),
'bias
_mask
'
:
mask_bias
.
detach
()}
if_calculated
.
copy_
(
torch
.
tensor
(
1
))
# pylint: disable=not-callabl
e
wrapper
.
if_calculated
=
Tru
e
return
mask
return
mask
class
LotteryTicketPruner
(
Pruner
):
class
LotteryTicketPruner
(
Pruner
):
...
@@ -250,7 +254,7 @@ class LotteryTicketPruner(Pruner):
...
@@ -250,7 +254,7 @@ class LotteryTicketPruner(Pruner):
5. Repeat step 2, 3, and 4.
5. Repeat step 2, 3, and 4.
"""
"""
def
__init__
(
self
,
model
,
config_list
,
optimizer
,
lr_scheduler
=
None
,
reset_weights
=
True
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
,
lr_scheduler
=
None
,
reset_weights
=
True
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -267,7 +271,7 @@ class LotteryTicketPruner(Pruner):
...
@@ -267,7 +271,7 @@ class LotteryTicketPruner(Pruner):
reset_weights : bool
reset_weights : bool
Whether reset weights and optimizer at the beginning of each round.
Whether reset weights and optimizer at the beginning of each round.
"""
"""
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
curr_prune_iteration
=
None
self
.
curr_prune_iteration
=
None
self
.
prune_iterations
=
self
.
_validate_config
(
config_list
)
self
.
prune_iterations
=
self
.
_validate_config
(
config_list
)
...
@@ -307,20 +311,16 @@ class LotteryTicketPruner(Pruner):
...
@@ -307,20 +311,16 @@ class LotteryTicketPruner(Pruner):
k
=
int
(
w_abs
.
numel
()
*
curr_sparsity
)
k
=
int
(
w_abs
.
numel
()
*
curr_sparsity
)
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
threshold
=
torch
.
topk
(
w_abs
.
view
(
-
1
),
k
,
largest
=
False
).
values
.
max
()
mask
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
mask
=
torch
.
gt
(
w_abs
,
threshold
).
type_as
(
weight
)
return
{
'weight'
:
mask
}
return
{
'weight
_mask
'
:
mask
}
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
"""
"""
Generate mask for the given ``weight``.
Generate mask for the given ``weight``.
Parameters
Parameters
----------
----------
layer : LayerInfo
wrapper : Module
The layer to be pruned
The layer to be pruned
config : dict
Pruning configurations for this weight
kwargs : dict
Auxiliary information
Returns
Returns
-------
-------
...
@@ -355,7 +355,7 @@ class LotteryTicketPruner(Pruner):
...
@@ -355,7 +355,7 @@ class LotteryTicketPruner(Pruner):
assert
self
.
curr_prune_iteration
<
self
.
prune_iterations
+
1
,
'Exceed the configured prune_iterations'
assert
self
.
curr_prune_iteration
<
self
.
prune_iterations
+
1
,
'Exceed the configured prune_iterations'
modules_wrapper
=
self
.
get_modules_wrapper
()
modules_wrapper
=
self
.
get_modules_wrapper
()
modules_to_compress
=
self
.
detec
t_modules_to_compress
()
modules_to_compress
=
self
.
ge
t_modules_to_compress
()
for
layer
,
config
in
modules_to_compress
:
for
layer
,
config
in
modules_to_compress
:
module_wrapper
=
None
module_wrapper
=
None
for
wrapper
in
modules_wrapper
:
for
wrapper
in
modules_wrapper
:
...
@@ -367,7 +367,7 @@ class LotteryTicketPruner(Pruner):
...
@@ -367,7 +367,7 @@ class LotteryTicketPruner(Pruner):
sparsity
=
config
.
get
(
'sparsity'
)
sparsity
=
config
.
get
(
'sparsity'
)
mask
=
self
.
_calc_mask
(
layer
.
module
.
weight
.
data
,
sparsity
,
module_wrapper
.
weight_mask
)
mask
=
self
.
_calc_mask
(
layer
.
module
.
weight
.
data
,
sparsity
,
module_wrapper
.
weight_mask
)
# TODO: directly use weight_mask is not good
# TODO: directly use weight_mask is not good
module_wrapper
.
weight_mask
.
copy_
(
mask
[
'weight'
]
)
module_wrapper
.
weight_mask
=
mask
[
'weight
_mask
'
]
# there is no mask for bias
# there is no mask for bias
# reinit weights back to original after new masks are generated
# reinit weights back to original after new masks are generated
...
...
src/sdk/pynni/nni/compression/torch/quantizers.py
View file @
75028bd7
...
@@ -13,14 +13,14 @@ logger = logging.getLogger(__name__)
...
@@ -13,14 +13,14 @@ logger = logging.getLogger(__name__)
class
NaiveQuantizer
(
Quantizer
):
class
NaiveQuantizer
(
Quantizer
):
"""quantize weight to 8 bits
"""quantize weight to 8 bits
"""
"""
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
layer_scale
=
{}
self
.
layer_scale
=
{}
def
quantize_weight
(
self
,
weight
,
config
,
op_name
,
**
kwargs
):
def
quantize_weight
(
self
,
weight
,
wrapper
,
**
kwargs
):
new_scale
=
weight
.
abs
().
max
()
/
127
new_scale
=
weight
.
abs
().
max
()
/
127
scale
=
max
(
self
.
layer_scale
.
get
(
op_
name
,
0
),
new_scale
)
scale
=
max
(
self
.
layer_scale
.
get
(
wrapper
.
name
,
0
),
new_scale
)
self
.
layer_scale
[
op_
name
]
=
scale
self
.
layer_scale
[
wrapper
.
name
]
=
scale
orig_type
=
weight
.
type
()
# TODO: user layer
orig_type
=
weight
.
type
()
# TODO: user layer
return
weight
.
div
(
scale
).
type
(
torch
.
int8
).
type
(
orig_type
).
mul
(
scale
)
return
weight
.
div
(
scale
).
type
(
torch
.
int8
).
type
(
orig_type
).
mul
(
scale
)
...
@@ -104,7 +104,7 @@ class QAT_Quantizer(Quantizer):
...
@@ -104,7 +104,7 @@ class QAT_Quantizer(Quantizer):
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
http://openaccess.thecvf.com/content_cvpr_2018/papers/Jacob_Quantization_and_Training_CVPR_2018_paper.pdf
"""
"""
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -124,9 +124,9 @@ class QAT_Quantizer(Quantizer):
...
@@ -124,9 +124,9 @@ class QAT_Quantizer(Quantizer):
- op_types : list of string
- op_types : list of string
types of nn.module you want to apply quantization, eg. 'Conv2d'
types of nn.module you want to apply quantization, eg. 'Conv2d'
"""
"""
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
steps
=
1
self
.
steps
=
1
modules_to_compress
=
self
.
detec
t_modules_to_compress
()
modules_to_compress
=
self
.
ge
t_modules_to_compress
()
for
layer
,
config
in
modules_to_compress
:
for
layer
,
config
in
modules_to_compress
:
layer
.
module
.
register_buffer
(
"zero_point"
,
None
)
layer
.
module
.
register_buffer
(
"zero_point"
,
None
)
layer
.
module
.
register_buffer
(
"scale"
,
None
)
layer
.
module
.
register_buffer
(
"scale"
,
None
)
...
@@ -181,7 +181,9 @@ class QAT_Quantizer(Quantizer):
...
@@ -181,7 +181,9 @@ class QAT_Quantizer(Quantizer):
real_val
=
op
.
scale
*
(
quantized_val
-
op
.
zero_point
)
real_val
=
op
.
scale
*
(
quantized_val
-
op
.
zero_point
)
return
real_val
return
real_val
def
quantize_weight
(
self
,
weight
,
config
,
op
,
**
kwargs
):
def
quantize_weight
(
self
,
weight
,
wrapper
,
**
kwargs
):
config
=
wrapper
.
config
module
=
wrapper
.
module
weight_bits
=
get_bits_length
(
config
,
'weight'
)
weight_bits
=
get_bits_length
(
config
,
'weight'
)
quant_start_step
=
config
.
get
(
'quant_start_step'
,
0
)
quant_start_step
=
config
.
get
(
'quant_start_step'
,
0
)
assert
weight_bits
>=
1
,
"quant bits length should be at least 1"
assert
weight_bits
>=
1
,
"quant bits length should be at least 1"
...
@@ -189,12 +191,14 @@ class QAT_Quantizer(Quantizer):
...
@@ -189,12 +191,14 @@ class QAT_Quantizer(Quantizer):
if
quant_start_step
>
self
.
steps
:
if
quant_start_step
>
self
.
steps
:
return
weight
return
weight
rmin
,
rmax
=
torch
.
min
(
weight
),
torch
.
max
(
weight
)
rmin
,
rmax
=
torch
.
min
(
weight
),
torch
.
max
(
weight
)
op
.
scale
,
op
.
zero_point
=
update_quantization_param
(
weight_bits
,
rmin
,
rmax
)
module
.
scale
,
module
.
zero_point
=
update_quantization_param
(
weight_bits
,
rmin
,
rmax
)
out
=
self
.
_quantize
(
weight_bits
,
op
,
weight
)
out
=
self
.
_quantize
(
weight_bits
,
module
,
weight
)
out
=
self
.
_dequantize
(
op
,
out
)
out
=
self
.
_dequantize
(
module
,
out
)
return
out
return
out
def
quantize_output
(
self
,
output
,
config
,
op
,
**
kwargs
):
def
quantize_output
(
self
,
output
,
wrapper
,
**
kwargs
):
config
=
wrapper
.
config
module
=
wrapper
.
module
output_bits
=
get_bits_length
(
config
,
'output'
)
output_bits
=
get_bits_length
(
config
,
'output'
)
quant_start_step
=
config
.
get
(
'quant_start_step'
,
0
)
quant_start_step
=
config
.
get
(
'quant_start_step'
,
0
)
assert
output_bits
>=
1
,
"quant bits length should be at least 1"
assert
output_bits
>=
1
,
"quant bits length should be at least 1"
...
@@ -203,18 +207,18 @@ class QAT_Quantizer(Quantizer):
...
@@ -203,18 +207,18 @@ class QAT_Quantizer(Quantizer):
return
output
return
output
current_min
,
current_max
=
torch
.
min
(
output
),
torch
.
max
(
output
)
current_min
,
current_max
=
torch
.
min
(
output
),
torch
.
max
(
output
)
op
.
tracked_min_biased
,
op
.
tracked_min
=
update_ema
(
op
.
tracked_min_biased
,
current_min
,
op
.
ema_decay
,
self
.
steps
)
module
.
tracked_min_biased
,
module
.
tracked_min
=
update_ema
(
module
.
tracked_min_biased
,
current_min
,
module
.
ema_decay
,
self
.
steps
)
op
.
tracked_max_biased
,
op
.
tracked_max
=
update_ema
(
op
.
tracked_max_biased
,
current_max
,
op
.
ema_decay
,
self
.
steps
)
module
.
tracked_max_biased
,
module
.
tracked_max
=
update_ema
(
module
.
tracked_max_biased
,
current_max
,
module
.
ema_decay
,
self
.
steps
)
op
.
scale
,
op
.
zero_point
=
update_quantization_param
(
output_bits
,
op
.
tracked_min
,
op
.
tracked_max
)
module
.
scale
,
module
.
zero_point
=
update_quantization_param
(
output_bits
,
module
.
tracked_min
,
module
.
tracked_max
)
out
=
self
.
_quantize
(
output_bits
,
op
,
output
)
out
=
self
.
_quantize
(
output_bits
,
module
,
output
)
out
=
self
.
_dequantize
(
op
,
out
)
out
=
self
.
_dequantize
(
module
,
out
)
return
out
return
out
def
fold_bn
(
self
,
config
,
**
kwargs
):
def
fold_bn
(
self
,
config
,
**
kwargs
):
# TODO simulate folded weight
# TODO simulate folded weight
pass
pass
def
step
(
self
):
def
step
_with_optimizer
(
self
):
"""
"""
override `compressor` `step` method, quantization only happens after certain number of steps
override `compressor` `step` method, quantization only happens after certain number of steps
"""
"""
...
@@ -226,11 +230,11 @@ class DoReFaQuantizer(Quantizer):
...
@@ -226,11 +230,11 @@ class DoReFaQuantizer(Quantizer):
Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
Zhou et al., DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients
(https://arxiv.org/abs/1606.06160)
(https://arxiv.org/abs/1606.06160)
"""
"""
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
def
quantize_weight
(
self
,
weight
,
config
,
**
kwargs
):
def
quantize_weight
(
self
,
weight
,
wrapper
,
**
kwargs
):
weight_bits
=
get_bits_length
(
config
,
'weight'
)
weight_bits
=
get_bits_length
(
wrapper
.
config
,
'weight'
)
out
=
weight
.
tanh
()
out
=
weight
.
tanh
()
out
=
out
/
(
2
*
out
.
abs
().
max
())
+
0.5
out
=
out
/
(
2
*
out
.
abs
().
max
())
+
0.5
out
=
self
.
quantize
(
out
,
weight_bits
)
out
=
self
.
quantize
(
out
,
weight_bits
)
...
@@ -256,17 +260,17 @@ class BNNQuantizer(Quantizer):
...
@@ -256,17 +260,17 @@ class BNNQuantizer(Quantizer):
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
Binarized Neural Networks: Training Deep Neural Networks with Weights and Activations Constrained to +1 or -1
(https://arxiv.org/abs/1602.02830)
(https://arxiv.org/abs/1602.02830)
"""
"""
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
quant_grad
=
ClipGrad
self
.
quant_grad
=
ClipGrad
def
quantize_weight
(
self
,
weight
,
config
,
**
kwargs
):
def
quantize_weight
(
self
,
weight
,
wrapper
,
**
kwargs
):
out
=
torch
.
sign
(
weight
)
out
=
torch
.
sign
(
weight
)
# remove zeros
# remove zeros
out
[
out
==
0
]
=
1
out
[
out
==
0
]
=
1
return
out
return
out
def
quantize_output
(
self
,
output
,
config
,
**
kwargs
):
def
quantize_output
(
self
,
output
,
wrapper
,
**
kwargs
):
out
=
torch
.
sign
(
output
)
out
=
torch
.
sign
(
output
)
# remove zeros
# remove zeros
out
[
out
==
0
]
=
1
out
[
out
==
0
]
=
1
...
...
src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
View file @
75028bd7
...
@@ -15,7 +15,7 @@ class WeightRankFilterPruner(Pruner):
...
@@ -15,7 +15,7 @@ class WeightRankFilterPruner(Pruner):
importance criterion in convolution layers to achieve a preset level of network sparsity.
importance criterion in convolution layers to achieve a preset level of network sparsity.
"""
"""
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -24,15 +24,17 @@ class WeightRankFilterPruner(Pruner):
...
@@ -24,15 +24,17 @@ class WeightRankFilterPruner(Pruner):
config_list : list
config_list : list
support key for each list item:
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
"""
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
register_buffer
(
"if_calculated"
,
torch
.
tensor
(
0
))
# pylint: disable=not-callable
self
.
set_wrappers_attribute
(
"if_calculated"
,
False
)
def
get_mask
(
self
,
base_mask
,
weight
,
num_prune
):
def
get_mask
(
self
,
base_mask
,
weight
,
num_prune
):
raise
NotImplementedError
(
'{} get_mask is not implemented'
.
format
(
self
.
__class__
.
__name__
))
raise
NotImplementedError
(
'{} get_mask is not implemented'
.
format
(
self
.
__class__
.
__name__
))
def
calc_mask
(
self
,
layer
,
config
,
**
kwargs
):
def
calc_mask
(
self
,
wrapper
,
**
kwargs
):
"""
"""
Calculate the mask of given layer.
Calculate the mask of given layer.
Filters with the smallest importance criterion of the kernel weights are masked.
Filters with the smallest importance criterion of the kernel weights are masked.
...
@@ -48,20 +50,21 @@ class WeightRankFilterPruner(Pruner):
...
@@ -48,20 +50,21 @@ class WeightRankFilterPruner(Pruner):
dictionary for storing masks
dictionary for storing masks
"""
"""
weight
=
layer
.
module
.
weight
.
data
weight
=
wrapper
.
module
.
weight
.
data
op_type
=
layer
.
type
op_type
=
wrapper
.
type
config
=
wrapper
.
config
assert
0
<=
config
.
get
(
'sparsity'
)
<
1
,
"sparsity must in the range [0, 1)"
assert
0
<=
config
.
get
(
'sparsity'
)
<
1
,
"sparsity must in the range [0, 1)"
assert
op_type
in
[
'Conv1d'
,
'Conv2d'
],
"only support Conv1d and Conv2d"
assert
op_type
in
[
'Conv1d'
,
'Conv2d'
],
"only support Conv1d and Conv2d"
assert
op_type
in
config
.
get
(
'op_types'
)
assert
op_type
in
config
.
get
(
'op_types'
)
if_calculated
=
kwargs
[
"if_calculated"
]
if
if_calculated
:
if
wrapper
.
if_calculated
:
return
None
return
None
mask_weight
=
torch
.
ones
(
weight
.
size
()).
type_as
(
weight
).
detach
()
mask_weight
=
torch
.
ones
(
weight
.
size
()).
type_as
(
weight
).
detach
()
if
hasattr
(
lay
er
.
module
,
'bias'
)
and
lay
er
.
module
.
bias
is
not
None
:
if
hasattr
(
wrapp
er
.
module
,
'bias'
)
and
wrapp
er
.
module
.
bias
is
not
None
:
mask_bias
=
torch
.
ones
(
lay
er
.
module
.
bias
.
size
()).
type_as
(
lay
er
.
module
.
bias
).
detach
()
mask_bias
=
torch
.
ones
(
wrapp
er
.
module
.
bias
.
size
()).
type_as
(
wrapp
er
.
module
.
bias
).
detach
()
else
:
else
:
mask_bias
=
None
mask_bias
=
None
mask
=
{
'weight'
:
mask_weight
,
'bias'
:
mask_bias
}
mask
=
{
'weight
_mask
'
:
mask_weight
,
'bias
_mask
'
:
mask_bias
}
try
:
try
:
filters
=
weight
.
size
(
0
)
filters
=
weight
.
size
(
0
)
num_prune
=
int
(
filters
*
config
.
get
(
'sparsity'
))
num_prune
=
int
(
filters
*
config
.
get
(
'sparsity'
))
...
@@ -69,7 +72,7 @@ class WeightRankFilterPruner(Pruner):
...
@@ -69,7 +72,7 @@ class WeightRankFilterPruner(Pruner):
return
mask
return
mask
mask
=
self
.
get_mask
(
mask
,
weight
,
num_prune
)
mask
=
self
.
get_mask
(
mask
,
weight
,
num_prune
)
finally
:
finally
:
if_calculated
.
copy_
(
torch
.
tensor
(
1
))
# pylint: disable=not-callabl
e
wrapper
.
if_calculated
=
Tru
e
return
mask
return
mask
...
@@ -82,7 +85,7 @@ class L1FilterPruner(WeightRankFilterPruner):
...
@@ -82,7 +85,7 @@ class L1FilterPruner(WeightRankFilterPruner):
https://arxiv.org/abs/1608.08710
https://arxiv.org/abs/1608.08710
"""
"""
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -91,9 +94,11 @@ class L1FilterPruner(WeightRankFilterPruner):
...
@@ -91,9 +94,11 @@ class L1FilterPruner(WeightRankFilterPruner):
config_list : list
config_list : list
support key for each list item:
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
"""
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
def
get_mask
(
self
,
base_mask
,
weight
,
num_prune
):
def
get_mask
(
self
,
base_mask
,
weight
,
num_prune
):
"""
"""
...
@@ -121,7 +126,7 @@ class L1FilterPruner(WeightRankFilterPruner):
...
@@ -121,7 +126,7 @@ class L1FilterPruner(WeightRankFilterPruner):
mask_weight
=
torch
.
gt
(
w_abs_structured
,
threshold
)[:,
None
,
None
,
None
].
expand_as
(
weight
).
type_as
(
weight
)
mask_weight
=
torch
.
gt
(
w_abs_structured
,
threshold
)[:,
None
,
None
,
None
].
expand_as
(
weight
).
type_as
(
weight
)
mask_bias
=
torch
.
gt
(
w_abs_structured
,
threshold
).
type_as
(
weight
)
mask_bias
=
torch
.
gt
(
w_abs_structured
,
threshold
).
type_as
(
weight
)
return
{
'weight'
:
mask_weight
.
detach
(),
'bias'
:
mask_bias
.
detach
()}
return
{
'weight
_mask
'
:
mask_weight
.
detach
(),
'bias
_mask
'
:
mask_bias
.
detach
()}
class
L2FilterPruner
(
WeightRankFilterPruner
):
class
L2FilterPruner
(
WeightRankFilterPruner
):
...
@@ -130,7 +135,7 @@ class L2FilterPruner(WeightRankFilterPruner):
...
@@ -130,7 +135,7 @@ class L2FilterPruner(WeightRankFilterPruner):
smallest L2 norm of the weights.
smallest L2 norm of the weights.
"""
"""
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
=
None
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -139,9 +144,11 @@ class L2FilterPruner(WeightRankFilterPruner):
...
@@ -139,9 +144,11 @@ class L2FilterPruner(WeightRankFilterPruner):
config_list : list
config_list : list
support key for each list item:
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
"""
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
def
get_mask
(
self
,
base_mask
,
weight
,
num_prune
):
def
get_mask
(
self
,
base_mask
,
weight
,
num_prune
):
"""
"""
...
@@ -167,7 +174,7 @@ class L2FilterPruner(WeightRankFilterPruner):
...
@@ -167,7 +174,7 @@ class L2FilterPruner(WeightRankFilterPruner):
mask_weight
=
torch
.
gt
(
w_l2_norm
,
threshold
)[:,
None
,
None
,
None
].
expand_as
(
weight
).
type_as
(
weight
)
mask_weight
=
torch
.
gt
(
w_l2_norm
,
threshold
)[:,
None
,
None
,
None
].
expand_as
(
weight
).
type_as
(
weight
)
mask_bias
=
torch
.
gt
(
w_l2_norm
,
threshold
).
type_as
(
weight
)
mask_bias
=
torch
.
gt
(
w_l2_norm
,
threshold
).
type_as
(
weight
)
return
{
'weight'
:
mask_weight
.
detach
(),
'bias'
:
mask_bias
.
detach
()}
return
{
'weight
_mask
'
:
mask_weight
.
detach
(),
'bias
_mask
'
:
mask_bias
.
detach
()}
class
FPGMPruner
(
WeightRankFilterPruner
):
class
FPGMPruner
(
WeightRankFilterPruner
):
...
@@ -177,7 +184,7 @@ class FPGMPruner(WeightRankFilterPruner):
...
@@ -177,7 +184,7 @@ class FPGMPruner(WeightRankFilterPruner):
https://arxiv.org/pdf/1811.00250.pdf
https://arxiv.org/pdf/1811.00250.pdf
"""
"""
def
__init__
(
self
,
model
,
config_list
):
def
__init__
(
self
,
model
,
config_list
,
optimizer
):
"""
"""
Parameters
Parameters
----------
----------
...
@@ -186,8 +193,11 @@ class FPGMPruner(WeightRankFilterPruner):
...
@@ -186,8 +193,11 @@ class FPGMPruner(WeightRankFilterPruner):
config_list: list
config_list: list
support key for each list item:
support key for each list item:
- sparsity: percentage of convolutional filters to be pruned.
- sparsity: percentage of convolutional filters to be pruned.
optimizer: torch.optim.Optimizer
Optimizer used to train model
"""
"""
super
().
__init__
(
model
,
config_list
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
assert
isinstance
(
optimizer
,
torch
.
optim
.
Optimizer
),
"FPGM pruner is an iterative pruner, please pass optimizer of the model to it"
def
get_mask
(
self
,
base_mask
,
weight
,
num_prune
):
def
get_mask
(
self
,
base_mask
,
weight
,
num_prune
):
"""
"""
...
@@ -208,9 +218,9 @@ class FPGMPruner(WeightRankFilterPruner):
...
@@ -208,9 +218,9 @@ class FPGMPruner(WeightRankFilterPruner):
"""
"""
min_gm_idx
=
self
.
_get_min_gm_kernel_idx
(
weight
,
num_prune
)
min_gm_idx
=
self
.
_get_min_gm_kernel_idx
(
weight
,
num_prune
)
for
idx
in
min_gm_idx
:
for
idx
in
min_gm_idx
:
base_mask
[
'weight'
][
idx
]
=
0.
base_mask
[
'weight
_mask
'
][
idx
]
=
0.
if
base_mask
[
'bias'
]
is
not
None
:
if
base_mask
[
'bias
_mask
'
]
is
not
None
:
base_mask
[
'bias'
][
idx
]
=
0.
base_mask
[
'bias
_mask
'
][
idx
]
=
0.
return
base_mask
return
base_mask
def
_get_min_gm_kernel_idx
(
self
,
weight
,
n
):
def
_get_min_gm_kernel_idx
(
self
,
weight
,
n
):
...
@@ -258,4 +268,4 @@ class FPGMPruner(WeightRankFilterPruner):
...
@@ -258,4 +268,4 @@ class FPGMPruner(WeightRankFilterPruner):
def
update_epoch
(
self
,
epoch
):
def
update_epoch
(
self
,
epoch
):
for
wrapper
in
self
.
get_modules_wrapper
():
for
wrapper
in
self
.
get_modules_wrapper
():
wrapper
.
registered_buffers
[
'if_calculated'
].
copy_
(
torch
.
tensor
(
0
))
# pylint: disable=not-callabl
e
wrapper
.
if_calculated
=
Fals
e
src/sdk/pynni/nni/curvefitting_assessor/curvefitting_assessor.py
View file @
75028bd7
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
import
logging
import
logging
import
datetime
import
datetime
from
nni.assessor
import
Assessor
,
AssessResult
from
nni.assessor
import
Assessor
,
AssessResult
from
nni.utils
import
extract_scalar_history
from
.model_factory
import
CurveModel
from
.model_factory
import
CurveModel
logger
=
logging
.
getLogger
(
'curvefitting_Assessor'
)
logger
=
logging
.
getLogger
(
'curvefitting_Assessor'
)
...
@@ -91,10 +92,11 @@ class CurvefittingAssessor(Assessor):
...
@@ -91,10 +92,11 @@ class CurvefittingAssessor(Assessor):
Exception
Exception
unrecognize exception in curvefitting_assessor
unrecognize exception in curvefitting_assessor
"""
"""
self
.
trial_history
=
trial_history
scalar_trial_history
=
extract_scalar_history
(
trial_history
)
self
.
trial_history
=
scalar_trial_history
if
not
self
.
set_best_performance
:
if
not
self
.
set_best_performance
:
return
AssessResult
.
Good
return
AssessResult
.
Good
curr_step
=
len
(
trial_history
)
curr_step
=
len
(
scalar_
trial_history
)
if
curr_step
<
self
.
start_step
:
if
curr_step
<
self
.
start_step
:
return
AssessResult
.
Good
return
AssessResult
.
Good
...
@@ -106,7 +108,7 @@ class CurvefittingAssessor(Assessor):
...
@@ -106,7 +108,7 @@ class CurvefittingAssessor(Assessor):
start_time
=
datetime
.
datetime
.
now
()
start_time
=
datetime
.
datetime
.
now
()
# Predict the final result
# Predict the final result
curvemodel
=
CurveModel
(
self
.
target_pos
)
curvemodel
=
CurveModel
(
self
.
target_pos
)
predict_y
=
curvemodel
.
predict
(
trial_history
)
predict_y
=
curvemodel
.
predict
(
scalar_
trial_history
)
logger
.
info
(
'Prediction done. Trial job id = %s. Predict value = %s'
,
trial_job_id
,
predict_y
)
logger
.
info
(
'Prediction done. Trial job id = %s. Predict value = %s'
,
trial_job_id
,
predict_y
)
if
predict_y
is
None
:
if
predict_y
is
None
:
logger
.
info
(
'wait for more information to predict precisely'
)
logger
.
info
(
'wait for more information to predict precisely'
)
...
...
src/sdk/pynni/nni/medianstop_assessor/medianstop_assessor.py
View file @
75028bd7
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
import
logging
import
logging
from
nni.assessor
import
Assessor
,
AssessResult
from
nni.assessor
import
Assessor
,
AssessResult
from
nni.utils
import
extract_scalar_history
logger
=
logging
.
getLogger
(
'medianstop_Assessor'
)
logger
=
logging
.
getLogger
(
'medianstop_Assessor'
)
...
@@ -91,20 +92,12 @@ class MedianstopAssessor(Assessor):
...
@@ -91,20 +92,12 @@ class MedianstopAssessor(Assessor):
if
curr_step
<
self
.
_start_step
:
if
curr_step
<
self
.
_start_step
:
return
AssessResult
.
Good
return
AssessResult
.
Good
try
:
scalar_trial_history
=
extract_scalar_history
(
trial_history
)
num_trial_history
=
[
float
(
ele
)
for
ele
in
trial_history
]
self
.
_update_data
(
trial_job_id
,
scalar_trial_history
)
except
(
TypeError
,
ValueError
)
as
error
:
logger
.
warning
(
'incorrect data type or value:'
)
logger
.
exception
(
error
)
except
Exception
as
error
:
logger
.
warning
(
'unrecognized exception in medianstop_assessor:'
)
logger
.
exception
(
error
)
self
.
_update_data
(
trial_job_id
,
num_trial_history
)
if
self
.
_high_better
:
if
self
.
_high_better
:
best_history
=
max
(
trial_history
)
best_history
=
max
(
scalar_
trial_history
)
else
:
else
:
best_history
=
min
(
trial_history
)
best_history
=
min
(
scalar_
trial_history
)
avg_array
=
[]
avg_array
=
[]
for
id_
in
self
.
_completed_avg_history
:
for
id_
in
self
.
_completed_avg_history
:
...
...
src/sdk/pynni/nni/msg_dispatcher.py
View file @
75028bd7
...
@@ -234,4 +234,5 @@ class MsgDispatcher(MsgDispatcherBase):
...
@@ -234,4 +234,5 @@ class MsgDispatcher(MsgDispatcherBase):
if
multi_thread_enabled
():
if
multi_thread_enabled
():
self
.
_handle_final_metric_data
(
data
)
self
.
_handle_final_metric_data
(
data
)
else
:
else
:
data
[
'value'
]
=
to_json
(
data
[
'value'
])
self
.
enqueue_command
(
CommandType
.
ReportMetricData
,
data
)
self
.
enqueue_command
(
CommandType
.
ReportMetricData
,
data
)
src/sdk/pynni/nni/nas/pytorch/_graph_utils.py
0 → 100644
View file @
75028bd7
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
# This file is copied from PyTorch 1.4, with bug fixes.
# Likely to be removed in future.
import
torch
from
tensorboard.compat.proto.config_pb2
import
RunMetadata
from
tensorboard.compat.proto.graph_pb2
import
GraphDef
from
tensorboard.compat.proto.step_stats_pb2
import
StepStats
,
DeviceStepStats
from
tensorboard.compat.proto.versions_pb2
import
VersionDef
from
torch.utils.tensorboard._pytorch_graph
import
GraphPy
,
CLASSTYPE_KIND
,
GETATTR_KIND
,
NodePyIO
,
NodePyOP
def
parse
(
graph
,
trace
,
args
=
None
,
omit_useless_nodes
=
True
):
"""This method parses an optimized PyTorch model graph and produces
a list of nodes and node stats for eventual conversion to TensorBoard
protobuf format.
Args:
graph (PyTorch module): The model graph to be parsed.
trace (PyTorch JIT TracedModule): The model trace to be parsed.
args (tuple): input tensor[s] for the model.
omit_useless_nodes (boolean): Whether to remove nodes from the graph.
"""
n_inputs
=
len
(
args
)
scope
=
{}
nodes_py
=
GraphPy
()
for
node
in
graph
.
inputs
():
if
omit_useless_nodes
:
if
len
(
node
.
uses
())
==
0
:
# number of user of the node (= number of outputs/ fanout)
continue
if
node
.
type
().
kind
()
!=
CLASSTYPE_KIND
:
nodes_py
.
append
(
NodePyIO
(
node
,
'input'
))
attr_to_scope
=
dict
()
node_to_name
=
lambda
d
:
str
(
d
).
split
(
":"
)[
0
].
strip
()
for
node
in
graph
.
nodes
():
if
node
.
kind
()
==
GETATTR_KIND
:
attr_name
=
node
.
s
(
'name'
)
node_name
=
node_to_name
(
node
)
parent
=
node
.
input
().
node
()
if
parent
.
kind
()
==
GETATTR_KIND
:
# If the parent node is not the top-level "self" node
parent_attr_name
=
parent
.
s
(
'name'
)
parent_scope
=
attr_to_scope
[
node_to_name
(
parent
)]
attr_scope
=
parent_scope
.
split
(
'/'
)[
-
1
]
attr_to_scope
[
node_name
]
=
'{}/{}.{}'
.
format
(
parent_scope
,
attr_scope
,
attr_name
)
else
:
attr_to_scope
[
node_name
]
=
'__module.{}'
.
format
(
attr_name
)
# We don't need classtype nodes; scope will provide this information
if
node
.
output
().
type
().
kind
()
!=
CLASSTYPE_KIND
:
node_py
=
NodePyOP
(
node
)
node_py
.
scopeName
=
attr_to_scope
[
node_name
]
nodes_py
.
append
(
node_py
)
else
:
nodes_py
.
append
(
NodePyOP
(
node
))
for
i
,
node
in
enumerate
(
graph
.
outputs
()):
# Create sink nodes for output ops
node_py
=
NodePyIO
(
node
,
'output'
)
node_py
.
debugName
=
"output.{}"
.
format
(
i
+
1
)
node_py
.
inputs
=
[
node
.
debugName
()]
nodes_py
.
append
(
node_py
)
def
parse_traced_name
(
module_name
):
prefix
=
'TracedModule['
suffix
=
']'
if
module_name
.
startswith
(
prefix
)
and
module_name
.
endswith
(
suffix
):
module_name
=
module_name
[
len
(
prefix
):
-
len
(
suffix
)]
return
module_name
alias_to_name
=
dict
()
base_name
=
parse_traced_name
(
trace
.
_name
)
for
name
,
module
in
trace
.
named_modules
(
prefix
=
'__module'
):
mod_name
=
parse_traced_name
(
module
.
_name
)
attr_name
=
name
.
split
(
'.'
)[
-
1
]
alias_to_name
[
name
]
=
'{}[{}]'
.
format
(
mod_name
,
attr_name
)
for
node
in
nodes_py
.
nodes_op
:
module_aliases
=
node
.
scopeName
.
split
(
'/'
)[
-
1
].
split
(
'.'
)
module_name
=
''
for
i
,
alias
in
enumerate
(
module_aliases
):
if
i
==
0
:
module_name
=
alias
node
.
scopeName
=
base_name
else
:
module_name
+=
'.'
+
alias
node
.
scopeName
+=
'/'
+
(
alias_to_name
[
module_name
]
if
module_name
in
alias_to_name
else
alias
)
nodes_py
.
populate_namespace_from_OP_to_IO
()
return
nodes_py
.
to_proto
()
def
graph
(
model
,
args
,
verbose
=
False
):
"""
This method processes a PyTorch model and produces a `GraphDef` proto
that can be logged to TensorBoard.
Args:
model (PyTorch module): The model to be parsed.
args (tuple): input tensor[s] for the model.
verbose (bool): Whether to print out verbose information while
processing.
"""
with
torch
.
onnx
.
set_training
(
model
,
False
):
# TODO: move outside of torch.onnx?
try
:
trace
=
torch
.
jit
.
trace
(
model
,
args
)
graph
=
trace
.
graph
torch
.
_C
.
_jit_pass_inline
(
graph
)
except
RuntimeError
as
e
:
print
(
e
)
print
(
'Error occurs, No graph saved'
)
raise
e
if
verbose
:
print
(
graph
)
list_of_nodes
=
parse
(
graph
,
trace
,
args
)
# We are hardcoding that this was run on CPU even though it might have actually
# run on GPU. Note this is what is shown in TensorBoard and has no bearing
# on actual execution.
# TODO: See if we can extract GPU vs CPU information from the PyTorch model
# and pass it correctly to TensorBoard.
#
# Definition of StepStats and DeviceStepStats can be found at
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
# and
# https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
stepstats
=
RunMetadata
(
step_stats
=
StepStats
(
dev_stats
=
[
DeviceStepStats
(
device
=
"/device:CPU:0"
)]))
return
GraphDef
(
node
=
list_of_nodes
,
versions
=
VersionDef
(
producer
=
22
)),
stepstats
# The producer version has been reverse engineered from standard
# TensorBoard logged data.
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py
View file @
75028bd7
...
@@ -23,7 +23,9 @@ class StackedLSTMCell(nn.Module):
...
@@ -23,7 +23,9 @@ class StackedLSTMCell(nn.Module):
curr_c
,
curr_h
=
m
(
inputs
,
(
prev_c
[
i
],
prev_h
[
i
]))
curr_c
,
curr_h
=
m
(
inputs
,
(
prev_c
[
i
],
prev_h
[
i
]))
next_c
.
append
(
curr_c
)
next_c
.
append
(
curr_c
)
next_h
.
append
(
curr_h
)
next_h
.
append
(
curr_h
)
inputs
=
curr_h
[
-
1
]
# current implementation only supports batch size equals 1,
# but the algorithm does not necessarily have this limitation
inputs
=
curr_h
[
-
1
].
view
(
1
,
-
1
)
return
next_c
,
next_h
return
next_c
,
next_h
...
...
src/sdk/pynni/nni/nas/pytorch/mutator.py
View file @
75028bd7
...
@@ -2,7 +2,9 @@
...
@@ -2,7 +2,9 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
import
logging
import
logging
from
collections
import
defaultdict
import
numpy
as
np
import
torch
import
torch
from
nni.nas.pytorch.base_mutator
import
BaseMutator
from
nni.nas.pytorch.base_mutator
import
BaseMutator
...
@@ -15,6 +17,7 @@ class Mutator(BaseMutator):
...
@@ -15,6 +17,7 @@ class Mutator(BaseMutator):
def
__init__
(
self
,
model
):
def
__init__
(
self
,
model
):
super
().
__init__
(
model
)
super
().
__init__
(
model
)
self
.
_cache
=
dict
()
self
.
_cache
=
dict
()
self
.
_connect_all
=
False
def
sample_search
(
self
):
def
sample_search
(
self
):
"""
"""
...
@@ -57,6 +60,74 @@ class Mutator(BaseMutator):
...
@@ -57,6 +60,74 @@ class Mutator(BaseMutator):
"""
"""
return
self
.
sample_final
()
return
self
.
sample_final
()
def
status
(
self
):
"""
Return current selection status of mutator.
Returns
-------
dict
A mapping from key of mutables to decisions. All weights (boolean type and float type)
are converted into real number values. Numpy arrays and tensors are converted into list.
"""
data
=
dict
()
for
k
,
v
in
self
.
_cache
.
items
():
if
torch
.
is_tensor
(
v
):
v
=
v
.
detach
().
cpu
().
numpy
()
if
isinstance
(
v
,
np
.
ndarray
):
v
=
v
.
astype
(
np
.
float32
).
tolist
()
data
[
k
]
=
v
return
data
def
graph
(
self
,
inputs
):
"""
Return model supernet graph.
Parameters
----------
inputs: tuple of tensor
Inputs that will be feeded into the network.
Returns
-------
dict
Containing ``node``, in Tensorboard GraphDef format.
Additional key ``mutable`` is a map from key to list of modules.
"""
if
not
torch
.
__version__
.
startswith
(
"1.4"
):
logger
.
warning
(
"Graph is only tested with PyTorch 1.4. Other versions might not work."
)
from
._graph_utils
import
graph
from
google.protobuf
import
json_format
# protobuf should be installed as long as tensorboard is installed
try
:
self
.
_connect_all
=
True
graph_def
,
_
=
graph
(
self
.
model
,
inputs
,
verbose
=
False
)
result
=
json_format
.
MessageToDict
(
graph_def
)
finally
:
self
.
_connect_all
=
False
# `mutable` is to map the keys to a list of corresponding modules.
# A key can be linked to multiple modules, use `dedup=False` to find them all.
result
[
"mutable"
]
=
defaultdict
(
list
)
for
mutable
in
self
.
mutables
.
traverse
(
deduplicate
=
False
):
# A module will be represent in the format of
# [{"type": "Net", "name": ""}, {"type": "Cell", "name": "cell1"}, {"type": "Conv2d": "name": "conv"}]
# which will be concatenated into Net/Cell[cell1]/Conv2d[conv] in frontend.
# This format is aligned with the scope name jit gives.
modules
=
mutable
.
name
.
split
(
"."
)
path
=
[
{
"type"
:
self
.
model
.
__class__
.
__name__
,
"name"
:
""
}
]
m
=
self
.
model
for
module
in
modules
:
m
=
getattr
(
m
,
module
)
path
.
append
({
"type"
:
m
.
__class__
.
__name__
,
"name"
:
module
})
result
[
"mutable"
][
mutable
.
key
].
append
(
path
)
return
result
def
on_forward_layer_choice
(
self
,
mutable
,
*
inputs
):
def
on_forward_layer_choice
(
self
,
mutable
,
*
inputs
):
"""
"""
On default, this method retrieves the decision obtained previously, and select certain operations.
On default, this method retrieves the decision obtained previously, and select certain operations.
...
@@ -75,6 +146,11 @@ class Mutator(BaseMutator):
...
@@ -75,6 +146,11 @@ class Mutator(BaseMutator):
tuple of torch.Tensor and torch.Tensor
tuple of torch.Tensor and torch.Tensor
Output and mask.
Output and mask.
"""
"""
if
self
.
_connect_all
:
return
self
.
_all_connect_tensor_reduction
(
mutable
.
reduction
,
[
op
(
*
inputs
)
for
op
in
mutable
.
choices
]),
\
torch
.
ones
(
mutable
.
length
)
def
_map_fn
(
op
,
*
inputs
):
def
_map_fn
(
op
,
*
inputs
):
return
op
(
*
inputs
)
return
op
(
*
inputs
)
...
@@ -101,6 +177,9 @@ class Mutator(BaseMutator):
...
@@ -101,6 +177,9 @@ class Mutator(BaseMutator):
tuple of torch.Tensor and torch.Tensor
tuple of torch.Tensor and torch.Tensor
Output and mask.
Output and mask.
"""
"""
if
self
.
_connect_all
:
return
self
.
_all_connect_tensor_reduction
(
mutable
.
reduction
,
tensor_list
),
\
torch
.
ones
(
mutable
.
n_candidates
)
mask
=
self
.
_get_decision
(
mutable
)
mask
=
self
.
_get_decision
(
mutable
)
assert
len
(
mask
)
==
mutable
.
n_candidates
,
\
assert
len
(
mask
)
==
mutable
.
n_candidates
,
\
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
mutable
.
n_candidates
)
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
mutable
.
n_candidates
)
...
@@ -131,6 +210,13 @@ class Mutator(BaseMutator):
...
@@ -131,6 +210,13 @@ class Mutator(BaseMutator):
return
torch
.
cat
(
tensor_list
,
dim
=
1
)
return
torch
.
cat
(
tensor_list
,
dim
=
1
)
raise
ValueError
(
"Unrecognized reduction policy:
\"
{}
\"
"
.
format
(
reduction_type
))
raise
ValueError
(
"Unrecognized reduction policy:
\"
{}
\"
"
.
format
(
reduction_type
))
def
_all_connect_tensor_reduction
(
self
,
reduction_type
,
tensor_list
):
if
reduction_type
==
"none"
:
return
tensor_list
if
reduction_type
==
"concat"
:
return
torch
.
cat
(
tensor_list
,
dim
=
1
)
return
torch
.
stack
(
tensor_list
).
sum
(
0
)
def
_get_decision
(
self
,
mutable
):
def
_get_decision
(
self
,
mutable
):
"""
"""
By default, this method checks whether `mutable.key` is already in the decision cache,
By default, this method checks whether `mutable.key` is already in the decision cache,
...
...
src/sdk/pynni/nni/platform/__init__.py
View file @
75028bd7
...
@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
...
@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
from
.standalone
import
*
from
.standalone
import
*
elif
trial_env_vars
.
NNI_PLATFORM
==
'unittest'
:
elif
trial_env_vars
.
NNI_PLATFORM
==
'unittest'
:
from
.test
import
*
from
.test
import
*
elif
trial_env_vars
.
NNI_PLATFORM
in
(
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
):
elif
trial_env_vars
.
NNI_PLATFORM
in
(
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
):
from
.local
import
*
from
.local
import
*
else
:
else
:
raise
RuntimeError
(
'Unknown platform %s'
%
trial_env_vars
.
NNI_PLATFORM
)
raise
RuntimeError
(
'Unknown platform %s'
%
trial_env_vars
.
NNI_PLATFORM
)
src/sdk/pynni/nni/utils.py
View file @
75028bd7
...
@@ -62,6 +62,13 @@ def extract_scalar_reward(value, scalar_key='default'):
...
@@ -62,6 +62,13 @@ def extract_scalar_reward(value, scalar_key='default'):
"""
"""
Extract scalar reward from trial result.
Extract scalar reward from trial result.
Parameters
----------
value : int, float, dict
the reported final metric data
scalar_key : str
the key name that indicates the numeric number
Raises
Raises
------
------
RuntimeError
RuntimeError
...
@@ -78,6 +85,26 @@ def extract_scalar_reward(value, scalar_key='default'):
...
@@ -78,6 +85,26 @@ def extract_scalar_reward(value, scalar_key='default'):
return
reward
return
reward
def
extract_scalar_history
(
trial_history
,
scalar_key
=
'default'
):
"""
Extract scalar value from a list of intermediate results.
Parameters
----------
trial_history : list
accumulated intermediate results of a trial
scalar_key : str
the key name that indicates the numeric number
Raises
------
RuntimeError
Incorrect final result: the final result should be float/int,
or a dict which has a key named "default" whose value is float/int.
"""
return
[
extract_scalar_reward
(
ele
,
scalar_key
)
for
ele
in
trial_history
]
def
convert_dict2tuple
(
value
):
def
convert_dict2tuple
(
value
):
"""
"""
convert dict type to tuple to solve unhashable problem.
convert dict type to tuple to solve unhashable problem.
...
@@ -90,7 +117,9 @@ def convert_dict2tuple(value):
...
@@ -90,7 +117,9 @@ def convert_dict2tuple(value):
def
init_dispatcher_logger
():
def
init_dispatcher_logger
():
""" Initialize dispatcher logging configuration"""
"""
Initialize dispatcher logging configuration
"""
logger_file_path
=
'dispatcher.log'
logger_file_path
=
'dispatcher.log'
if
dispatcher_env_vars
.
NNI_LOG_DIRECTORY
is
not
None
:
if
dispatcher_env_vars
.
NNI_LOG_DIRECTORY
is
not
None
:
logger_file_path
=
os
.
path
.
join
(
dispatcher_env_vars
.
NNI_LOG_DIRECTORY
,
logger_file_path
)
logger_file_path
=
os
.
path
.
join
(
dispatcher_env_vars
.
NNI_LOG_DIRECTORY
,
logger_file_path
)
...
...
src/sdk/pynni/tests/test_compressor.py
View file @
75028bd7
This diff is collapsed.
Click to expand it.
src/webui/src/components/Modals/CustomizedTrial.tsx
View file @
75028bd7
...
@@ -59,29 +59,34 @@ class Customize extends React.Component<CustomizeProps, CustomizeState> {
...
@@ -59,29 +59,34 @@ class Customize extends React.Component<CustomizeProps, CustomizeState> {
});
});
// true: parameters are wrong
// true: parameters are wrong
let
flag
=
false
;
let
parametersIllegal
=
false
;
Object
.
keys
(
customized
).
map
(
item
=>
{
Object
.
keys
(
customized
).
map
(
item
=>
{
if
(
item
!==
'
tag
'
)
{
if
(
item
!==
'
tag
'
)
{
// unified data type
// unified data type
if
(
typeof
copyTrialParameter
[
item
]
===
'
number
'
&&
typeof
customized
[
item
]
===
'
string
'
)
{
if
(
typeof
copyTrialParameter
[
item
]
===
'
number
'
&&
typeof
customized
[
item
]
===
'
string
'
)
{
customized
[
item
]
=
JSON
.
parse
(
customized
[
item
]);
customized
[
item
]
=
JSON
.
parse
(
customized
[
item
]);
}
}
if
(
searchSpace
[
item
]
===
undefined
)
{
// sometimes the schema of trial parameters is different from search space
// e.g. Batch Tuner
return
;
}
if
(
searchSpace
[
item
].
_type
===
'
choice
'
)
{
if
(
searchSpace
[
item
].
_type
===
'
choice
'
)
{
if
(
searchSpace
[
item
].
_value
.
find
((
val
:
string
|
number
)
=>
if
(
searchSpace
[
item
].
_value
.
find
((
val
:
string
|
number
)
=>
val
===
customized
[
item
])
===
undefined
)
{
val
===
customized
[
item
])
===
undefined
)
{
flag
=
true
;
parametersIllegal
=
true
;
return
;
return
;
}
}
}
else
{
}
else
{
if
(
customized
[
item
]
<
searchSpace
[
item
].
_value
[
0
]
if
(
customized
[
item
]
<
searchSpace
[
item
].
_value
[
0
]
||
customized
[
item
]
>
searchSpace
[
item
].
_value
[
1
])
{
||
customized
[
item
]
>
searchSpace
[
item
].
_value
[
1
])
{
flag
=
true
;
parametersIllegal
=
true
;
return
;
return
;
}
}
}
}
}
}
});
});
if
(
flag
!==
false
)
{
if
(
parametersIllegal
!==
false
)
{
// open the warning modal
// open the warning modal
this
.
setState
(()
=>
({
isShowWarning
:
true
,
customParameters
:
customized
}));
this
.
setState
(()
=>
({
isShowWarning
:
true
,
customParameters
:
customized
}));
}
else
{
}
else
{
...
@@ -269,4 +274,4 @@ class Customize extends React.Component<CustomizeProps, CustomizeState> {
...
@@ -269,4 +274,4 @@ class Customize extends React.Component<CustomizeProps, CustomizeState> {
}
}
}
}
export
default
Customize
;
export
default
Customize
;
\ No newline at end of file
src/webui/src/components/Overview.tsx
View file @
75028bd7
This diff is collapsed.
Click to expand it.
src/webui/src/components/overview/SuccessTable.tsx
View file @
75028bd7
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment