OpenDAS / nni — commit aa316742

Unverified commit, authored Feb 21, 2020 by SparkSnail; committed via GitHub on Feb 21, 2020.
Parents: 3fe117f0, 24fa4619

    Merge pull request #233 from microsoft/master

    merge master

The full commit changes 285 files; this page shows 20 changed files, with 823 additions and 434 deletions (+823 -434).
Changed files shown on this page:

    src/sdk/pynni/nni/compression/torch/__init__.py                          +1    -0
    src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py    +15   -10
    src/sdk/pynni/nni/compression/torch/apply_compression.py  (new file)     +70   -0
    src/sdk/pynni/nni/compression/torch/compressor.py                        +277  -99
    src/sdk/pynni/nni/compression/torch/pruners.py                           +68   -72
    src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py        +8    -9
    src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py                 +2    -0
    src/sdk/pynni/nni/medianstop_assessor/test.py                            +3    -3
    src/sdk/pynni/nni/msg_dispatcher.py                                      +5    -0
    src/sdk/pynni/nni/nas/pytorch/base_mutator.py                            +29   -3
    src/sdk/pynni/nni/nas/pytorch/base_trainer.py                            +17   -0
    src/sdk/pynni/nni/nas/pytorch/callbacks.py                               +65   -0
    src/sdk/pynni/nni/nas/pytorch/cdarts/mutator.py                          +7    -10
    src/sdk/pynni/nni/nas/pytorch/cdarts/trainer.py                          +61   -61
    src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py                     +34   -18
    src/sdk/pynni/nni/nas/pytorch/darts/mutator.py                           +19   -14
    src/sdk/pynni/nni/nas/pytorch/darts/trainer.py                           +36   -36
    src/sdk/pynni/nni/nas/pytorch/enas/mutator.py                            +29   -29
    src/sdk/pynni/nni/nas/pytorch/enas/trainer.py                            +52   -52
    src/sdk/pynni/nni/nas/pytorch/fixed.py                                   +25   -18
src/sdk/pynni/nni/compression/torch/__init__.py
@@ -6,3 +6,4 @@ from .pruners import *
 from .weight_rank_filter_pruners import *
 from .activation_rank_filter_pruners import *
 from .quantizers import *
+from .apply_compression import apply_compression_results
src/sdk/pynni/nni/compression/torch/activation_rank_filter_pruners.py
@@ -32,7 +32,7 @@ class ActivationRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
+        self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable
         self.statistics_batch_num = statistics_batch_num
         self.collected_activation = {}
         self.hooks = {}
@@ -48,22 +48,29 @@ class ActivationRankFilterPruner(Pruner):
         """
         Compress the model, register a hook for collecting activations.
         """
+        if self.modules_wrapper is not None:
+            # already compressed
+            return self.bound_model
+        else:
+            self.modules_wrapper = []
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
-            self._instrument_layer(layer, config)
+            wrapper = self._wrap_modules(layer, config)
+            self.modules_wrapper.append(wrapper)
             self.collected_activation[layer.name] = []

             def _hook(module_, input_, output, name=layer.name):
                 if len(self.collected_activation[name]) < self.statistics_batch_num:
                     self.collected_activation[name].append(self.activation(output.detach().cpu()))

-            layer.module.register_forward_hook(_hook)
+            wrapper.module.register_forward_hook(_hook)
+        self._wrap_model()
         return self.bound_model

     def get_mask(self, base_mask, activations, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Filters with the smallest importance criterion which is calculated from the activation are masked.
@@ -82,14 +89,13 @@ class ActivationRankFilterPruner(Pruner):
         """
         weight = layer.module.weight.data
         op_name = layer.name
         op_type = layer.type
         assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
         assert op_type in ['Conv2d'], "only support Conv2d"
         assert op_type in config.get('op_types')
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
         mask_weight = torch.ones(weight.size()).type_as(weight).detach()
         if hasattr(layer.module, 'bias') and layer.module.bias is not None:
             mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
@@ -104,8 +110,7 @@ class ActivationRankFilterPruner(Pruner):
                 mask = self.get_mask(mask, self.collected_activation[layer.name], num_prune)
         finally:
             if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
-                self.mask_dict.update({op_name: mask})
-                self.mask_calculated_ops.add(op_name)
+                if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable
         return mask
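The net effect of this refactor is that `calc_mask` implementations now receive the wrapper's registered buffers through `**kwargs` (including the `if_calculated` flag registered in `__init__`) instead of tracking state in `mask_calculated_ops` / `mask_dict`. Below is a minimal sketch of a custom pruner written against that new contract; the class itself is hypothetical and only mirrors the pattern used by the built-in pruners in this diff.

    import torch
    from nni.compression.torch.compressor import Pruner

    class OneShotMagnitudePruner(Pruner):
        """Hypothetical pruner illustrating the new calc_mask(layer, config, **kwargs) contract."""

        def __init__(self, model, config_list):
            super().__init__(model, config_list)
            # this buffer is cloned into every PrunerModuleWrapper and handed back via kwargs
            self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable

        def calc_mask(self, layer, config, **kwargs):
            if_calculated = kwargs["if_calculated"]
            if if_calculated:
                return None  # the wrapper keeps applying the mask it already stores
            weight = layer.module.weight.data
            k = int(weight.numel() * config['sparsity'])
            if k == 0:
                return {'weight': torch.ones(weight.shape).type_as(weight)}
            threshold = torch.topk(weight.abs().view(-1), k, largest=False)[0].max()
            mask = {'weight': torch.gt(weight.abs(), threshold).type_as(weight)}
            if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable
            return mask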
src/sdk/pynni/nni/compression/torch/apply_compression.py  (new file, mode 0 → 100644)

    # Copyright (c) Microsoft Corporation.
    # Licensed under the MIT license.

    import logging
    import torch
    from .compressor import Pruner

    logger = logging.getLogger('torch apply compression')

    def apply_compression_results(model, masks_file):
        """
        Apply the masks from ```masks_file``` to the model

        Parameters
        ----------
        model : torch.nn.module
            The model to be compressed
        masks_file : str
            The path of the mask file
        """
        apply_comp = ApplyCompression(model, masks_file)
        apply_comp.compress()

    class ApplyCompression(Pruner):
        """
        This class is not to generate masks, but applying existing masks
        """

        def __init__(self, model, masks_file):
            """
            Parameters
            ----------
            model : torch.nn.module
                Model to be masked
            masks_file : str
                The path of user provided mask file
            """
            self.bound_model = model
            self.masks = torch.load(masks_file)
            for module_name in self.masks:
                print('module_name: ', module_name)
            config_list = self._build_config()
            super().__init__(model, config_list)

        def _build_config(self):
            op_names = []
            for module_name in self.masks:
                op_names.append(module_name)
            return [{'sparsity': 1, 'op_types': ['default', 'BatchNorm2d'], 'op_names': op_names}]

        def calc_mask(self, layer, config, **kwargs):
            """
            Directly return the corresponding mask

            Parameters
            ----------
            layer : LayerInfo
                The layer to be pruned
            config : dict
                Pruning configurations for this weight
            kwargs : dict
                Auxiliary information

            Returns
            -------
            dict
                Mask of the layer
            """
            assert layer.name in self.masks
            return self.masks[layer.name]
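For context, a minimal way this new helper might be used; the model and file path below are placeholders, and the torchvision import is only an illustration.

    import torch
    from torchvision.models import resnet18  # any torch.nn.Module works; used here purely for illustration
    from nni.compression.torch import apply_compression_results

    model = resnet18()
    # 'mask.pth' is a placeholder: it should be the mask file written by Pruner.export_model(..., mask_path='mask.pth')
    apply_compression_results(model, 'mask.pth')
    # the loaded masks are multiplied into the matching modules' weights on every forward pass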
src/sdk/pynni/nni/compression/torch/compressor.py
@@ -14,8 +14,11 @@ class LayerInfo:
         self.name = name
         self.type = type(module).__name__
         self._forward = None

+def _setattr(model, name, module):
+    name_list = name.split(".")
+    for name in name_list[:-1]:
+        model = getattr(model, name)
+    setattr(model, name_list[-1], module)

 class Compressor:
     """
@@ -36,6 +39,9 @@ class Compressor:
         self.bound_model = model
         self.config_list = config_list
         self.modules_to_compress = None
+        self.modules_wrapper = None
+        self.buffers = {}
+        self.is_wrapped = False

     def detect_modules_to_compress(self):
         """
@@ -51,21 +57,60 @@ class Compressor:
                 self.modules_to_compress.append((layer, config))
         return self.modules_to_compress

+    def _wrap_model(self):
+        """
+        wrap all modules that needed to be compressed
+        """
+        for wrapper in reversed(self.get_modules_wrapper()):
+            _setattr(self.bound_model, wrapper.name, wrapper)
+        self.is_wrapped = True
+
+    def _unwrap_model(self):
+        """
+        unwrap all modules that needed to be compressed
+        """
+        for wrapper in self.get_modules_wrapper():
+            _setattr(self.bound_model, wrapper.name, wrapper.module)
+        self.is_wrapped = False
+
     def compress(self):
         """
         Compress the model with algorithm implemented by subclass.
         The model will be instrumented and user should never edit it after calling this method.
         `self.modules_to_compress` records all the to-be-compressed layers

         Returns
         -------
         torch.nn.Module
             model with specified modules compressed.
         """
+        if self.modules_wrapper is not None:
+            # already compressed
+            return self.bound_model
+        else:
+            self.modules_wrapper = []
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
-            self._instrument_layer(layer, config)
+            wrapper = self._wrap_modules(layer, config)
+            self.modules_wrapper.append(wrapper)
+        self._wrap_model()
         return self.bound_model

+    def register_buffer(self, name, value):
+        """
+        To register buffers used in wrapped module's forward method.
+        """
+        self.buffers[name] = value
+
     def get_modules_to_compress(self):
         """
-        To obtain all the to-be-compressed layers.
+        To obtain all the to-be-compressed modules.

         Returns
         -------
@@ -75,6 +120,17 @@ class Compressor:
         """
         return self.modules_to_compress

+    def get_modules_wrapper(self):
+        """
+        To obtain all the wrapped modules.
+
+        Returns
+        -------
+        list
+            a list of the wrapped modules
+        """
+        return self.modules_wrapper
+
     def select_config(self, layer):
         """
         Find the configuration for `layer` by parsing `self.config_list`
@@ -93,13 +149,24 @@ class Compressor:
         ret = None
         for config in self.config_list:
             config = config.copy()
-            config['op_types'] = self._expand_config_op_types(config)
-            if layer.type not in config['op_types']:
+            # expand config if key `default` is in config['op_types']
+            if 'op_types' in config and 'default' in config['op_types']:
+                expanded_op_types = []
+                for op_type in config['op_types']:
+                    if op_type == 'default':
+                        expanded_op_types.extend(default_layers.weighted_modules)
+                    else:
+                        expanded_op_types.append(op_type)
+                config['op_types'] = expanded_op_types
+
+            # check if condition is satisified
+            if 'op_types' in config and layer.type not in config['op_types']:
                 continue
-            if config.get('op_names') and layer.name not in config['op_names']:
+            if 'op_names' in config and layer.name not in config['op_names']:
                 continue
             ret = config
-        if ret is None or ret.get('exclude'):
+        if ret is None or 'exclude' in ret:
             return None
         return ret
@@ -119,7 +186,7 @@ class Compressor:
         If user want to update model every step, user can override this method
         """

-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
         This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer`
@@ -132,17 +199,66 @@ class Compressor:
         """
         raise NotImplementedError()

-    def _expand_config_op_types(self, config):
-        if config is None:
-            return []
-        expanded_op_types = []
-        for op_type in config.get('op_types', []):
-            if op_type == 'default':
-                expanded_op_types.extend(default_layers.weighted_modules)
-            else:
-                expanded_op_types.append(op_type)
-        return expanded_op_types
+class PrunerModuleWrapper(torch.nn.Module):
+    def __init__(self, module, module_name, module_type, config, pruner):
+        """
+        Wrap an module to enable data parallel, forward method customization and buffer registeration.
+
+        Parameters
+        ----------
+        module : pytorch module
+            the module user wants to compress
+        config : dict
+            the configurations that users specify for compression
+        module_name : str
+            the name of the module to compress, wrapper module shares same name
+        module_type : str
+            the type of the module to compress
+        pruner : Pruner
+            the pruner used to calculate mask
+        """
+        super().__init__()
+        # origin layer information
+        self.module = module
+        self.name = module_name
+        self.type = module_type
+        # config and pruner
+        self.config = config
+        self.pruner = pruner
+        self.registered_buffers = []
+        # register buffer for mask
+        self.register_buffer("weight_mask", torch.ones(self.module.weight.shape))
+        if hasattr(self.module, 'bias') and self.module.bias is not None:
+            self.register_buffer("bias_mask", torch.ones(self.module.bias.shape))
+        else:
+            self.register_buffer("bias_mask", None)
+        self.registered_buffers.append('weight_mask')
+        self.registered_buffers.append('bias_mask')
+        # register user specified buffer
+        for name in self.pruner.buffers:
+            self.register_buffer(name, self.pruner.buffers[name].clone())
+            self.registered_buffers.append(name)
+
+    def get_registered_buffers(self):
+        buffers = {}
+        for name in self.registered_buffers:
+            buffers[name] = getattr(self, name)
+        return buffers
+
+    def forward(self, *inputs):
+        mask = self.pruner.calc_mask(LayerInfo(self.name, self.module), self.config, **self.get_registered_buffers())
+        if mask is not None:
+            self.weight_mask.copy_(mask['weight'])
+        # apply mask to weight
+        self.module.weight.data = self.module.weight.data.mul_(self.weight_mask)
+        # apply mask to bias
+        if hasattr(self.module, 'bias') and self.module.bias is not None:
+            if mask is not None and 'bias' in mask:
+                self.bias_mask.copy_(mask['bias'])
+            self.module.bias.data = self.module.bias.data.mul_(self.bias_mask)
+        return self.module(*inputs)

 class Pruner(Compressor):
     """
@@ -158,9 +274,8 @@ class Pruner(Compressor):
     def __init__(self, model, config_list):
         super().__init__(model, config_list)
-        self.mask_dict = {}

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Pruners should overload this method to provide mask for weight tensors.
         The mask must have the same shape and type comparing to the weight.
@@ -176,9 +291,9 @@ class Pruner(Compressor):
         """
         raise NotImplementedError("Pruners must overload calc_mask()")

-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
-        Create a wrapper forward function to replace the original one.
+        Create a wrapper module to replace the original one.

         Parameters
         ----------
@@ -187,30 +302,14 @@ class Pruner(Compressor):
         config : dict
             the configuration for generating the mask
         """
-        assert layer._forward is None, 'Each model can only be compressed once'
-        if not _check_weight(layer.module):
-            _logger.warning('Module %s does not have parameter "weight"', layer.name)
-            return
-        layer._forward = layer.module.forward
-
-        def new_forward(*inputs):
-            mask = self.calc_mask(layer, config)
-            # apply mask to weight
-            old_weight = layer.module.weight.data
-            mask_weight = mask['weight']
-            layer.module.weight.data = old_weight.mul(mask_weight)
-            # apply mask to bias
-            if mask.__contains__('bias') and hasattr(layer.module, 'bias') and layer.module.bias is not None:
-                old_bias = layer.module.bias.data
-                mask_bias = mask['bias']
-                layer.module.bias.data = old_bias.mul(mask_bias)
-            # calculate forward
-            ret = layer._forward(*inputs)
-            return ret
-
-        layer.module.forward = new_forward
+        _logger.info("compressing module %s.", layer.name)
+        wrapper = PrunerModuleWrapper(layer.module, layer.name, layer.type, config, self)
+        assert hasattr(layer.module, 'weight'), "module %s does not have 'weight' attribute" % layer.name
+        # move newly registered buffers to the same device of weight
+        wrapper.to(layer.module.weight.device)
+        return wrapper

-    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None):
+    def export_model(self, model_path, mask_path=None, onnx_path=None, input_shape=None, device=None):
         """
         Export pruned model weights, masks and onnx model(optional)
@@ -224,35 +323,144 @@ class Pruner(Compressor):
             (optional) path to save onnx model
         input_shape : list or tuple
             input shape to onnx model
+        device : torch.device
+            device of the model, used to place the dummy input tensor for exporting onnx file.
+            the tensor is placed on cpu if ```device``` is None
         """
-        if self.detect_modules_to_compress() and not self.mask_dict:
-            _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
+        # if self.detect_modules_to_compress() and not self.mask_dict:
+        #     _logger.warning('You may not use self.mask_dict in base Pruner class to record masks')
         assert model_path is not None, 'model_path must be specified'
-        for name, m in self.bound_model.named_modules():
-            if name == "":
-                continue
-            masks = self.mask_dict.get(name)
-            if masks is not None:
-                mask_sum = masks['weight'].sum().item()
-                mask_num = masks['weight'].numel()
-                _logger.info('Layer: %s Sparsity: %.2f', name, 1 - mask_sum / mask_num)
-                m.weight.data = m.weight.data.mul(masks['weight'])
-                if masks.__contains__('bias') and hasattr(m, 'bias') and m.bias is not None:
-                    m.bias.data = m.bias.data.mul(masks['bias'])
-            else:
-                _logger.info('Layer: %s NOT compressed', name)
+        mask_dict = {}
+        self._unwrap_model()  # used for generating correct state_dict name without wrapper state
+        for wrapper in self.get_modules_wrapper():
+            weight_mask = wrapper.weight_mask
+            bias_mask = wrapper.bias_mask
+            if weight_mask is not None:
+                mask_sum = weight_mask.sum().item()
+                mask_num = weight_mask.numel()
+                _logger.info('Layer: %s Sparsity: %.2f', wrapper.name, 1 - mask_sum / mask_num)
+                wrapper.module.weight.data = wrapper.module.weight.data.mul(weight_mask)
+            if bias_mask is not None:
+                wrapper.module.bias.data = wrapper.module.bias.data.mul(bias_mask)
+            # save mask to dict
+            mask_dict[wrapper.name] = {"weight": weight_mask, "bias": bias_mask}
         torch.save(self.bound_model.state_dict(), model_path)
         _logger.info('Model state_dict saved to %s', model_path)
         if mask_path is not None:
-            torch.save(self.mask_dict, mask_path)
+            torch.save(mask_dict, mask_path)
             _logger.info('Mask dict saved to %s', mask_path)
         if onnx_path is not None:
             assert input_shape is not None, 'input_shape must be specified to export onnx model'
             # input info needed
+            if device is None:
+                device = torch.device('cpu')
             input_data = torch.Tensor(*input_shape)
-            torch.onnx.export(self.bound_model, input_data, onnx_path)
+            torch.onnx.export(self.bound_model, input_data.to(device), onnx_path)
             _logger.info('Model in onnx with input shape %s saved to %s', input_data.shape, onnx_path)
+        self._wrap_model()
+
+    def load_model_state_dict(self, model_state):
+        """
+        Load the state dict saved from unwrapped model.
+
+        Parameters:
+        -----------
+        model_state : dict
+            state dict saved from unwrapped model
+        """
+        if self.is_wrapped:
+            self._unwrap_model()
+            self.bound_model.load_state_dict(model_state)
+            self._wrap_model()
+        else:
+            self.bound_model.load_state_dict(model_state)

+class QuantizerModuleWrapper(torch.nn.Module):
+    def __init__(self, module, module_name, module_type, config, quantizer):
+        """
+        Wrap an module to enable data parallel, forward method customization and buffer registeration.
+
+        Parameters
+        ----------
+        module : pytorch module
+            the module user wants to compress
+        config : dict
+            the configurations that users specify for compression
+        module_name : str
+            the name of the module to compress, wrapper module shares same name
+        module_type : str
+            the type of the module to compress
+        quantizer :quantizer
+            the quantizer used to calculate mask
+        """
+        super().__init__()
+        # origin layer information
+        self.module = module
+        self.name = module_name
+        self.type = module_type
+        # config and pruner
+        self.config = config
+        self.quantizer = quantizer
+        self.registered_buffers = []
+
+        # register buffer and parameter
+        # old_weight is used to store origin weight and weight is used to store quantized weight
+        # the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
+        # if weight is leaf , then old_weight can not be updated.
+        if 'weight' in config['quant_types']:
+            if not _check_weight(self.module):
+                _logger.warning('Module %s does not have parameter "weight"', self.name)
+            else:
+                self.module.register_parameter('old_weight', torch.nn.Parameter(self.module.weight))
+                delattr(self.module, 'weight')
+                self.module.register_buffer('weight', self.module.old_weight)
+
+        # register user specified buffer
+        for name in self.quantizer.buffers:
+            self.register_buffer(name, self.quantizer.buffers[name].clone())
+            self.registered_buffers.append(name)
+
+    def get_registered_buffers(self):
+        buffers = {}
+        for name in self.registered_buffers:
+            buffers[name] = getattr(self, name)
+        return buffers
+
+    def forward(self, *inputs):
+        if 'input' in self.config['quant_types']:
+            inputs = self.quantizer.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantizer.quantize_input,
+                                                     self.config, LayerInfo(self.name, self.module),
+                                                     **self.get_registered_buffers())
+        if 'weight' in self.config['quant_types'] and _check_weight(self.module):
+            new_weight = self.quantizer.quant_grad.apply(self.module.old_weight, QuantType.QUANT_WEIGHT,
+                                                         self.quantizer.quantize_weight, self.config,
+                                                         LayerInfo(self.name, self.module),
+                                                         **self.get_registered_buffers())
+            self.module.weight = new_weight
+            result = self.module(*inputs)
+        else:
+            result = self.module(*inputs)
+        if 'output' in self.config['quant_types']:
+            result = self.quantizer.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantizer.quantize_output,
+                                                     self.config, LayerInfo(self.name, self.module),
+                                                     **self.get_registered_buffers())
+        return result

 class Quantizer(Compressor):
     """
@@ -303,7 +511,7 @@ class Quantizer(Compressor):
         raise NotImplementedError('Quantizer must overload quantize_input()')

-    def _instrument_layer(self, layer, config):
+    def _wrap_modules(self, layer, config):
         """
         Create a wrapper forward function to replace the original one.

         Parameters
@@ -313,7 +521,6 @@ class Quantizer(Compressor):
         config : dict
             the configuration for quantization
         """
-        assert layer._forward is None, 'Each model can only be compressed once'
         assert 'quant_types' in config, 'must provide quant_types in config'
         assert isinstance(config['quant_types'], list), 'quant_types must be list type'
         assert 'quant_bits' in config, 'must provide quant_bits in config'
@@ -323,35 +530,7 @@ class Quantizer(Compressor):
         for quant_type in config['quant_types']:
             assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type

-        if 'weight' in config['quant_types']:
-            if not _check_weight(layer.module):
-                _logger.warning('Module %s does not have parameter "weight"', layer.name)
-            else:
-                # old_weight is used to store origin weight and weight is used to store quantized weight
-                # the reason why weight is buffer instead of parameter is because in pytorch parameter is used as leaf
-                # if weight is leaf , then old_weight can not be updated.
-                layer.module.register_parameter('old_weight', torch.nn.Parameter(layer.module.weight))
-                delattr(layer.module, 'weight')
-                layer.module.register_buffer('weight', layer.module.old_weight)
-
-        layer._forward = layer.module.forward
-
-        def new_forward(*inputs):
-            if 'input' in config['quant_types']:
-                inputs = self.quant_grad.apply(inputs, QuantType.QUANT_INPUT, self.quantize_input, config, layer)
-
-            if 'weight' in config['quant_types'] and _check_weight(layer.module):
-                new_weight = self.quant_grad.apply(layer.module.old_weight, QuantType.QUANT_WEIGHT, self.quantize_weight, config, layer)
-                layer.module.weight = new_weight
-                result = layer._forward(*inputs)
-            else:
-                result = layer._forward(*inputs)
-
-            if 'output' in config['quant_types']:
-                result = self.quant_grad.apply(result, QuantType.QUANT_OUTPUT, self.quantize_output, config, layer)
-            return result
-
-        layer.module.forward = new_forward
+        return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self)

 class QuantType:
     """
@@ -387,19 +566,18 @@ class QuantGrad(torch.autograd.Function):
         return grad_output

     @staticmethod
-    def forward(ctx, tensor, quant_type, quant_func, config, layer):
+    def forward(ctx, tensor, quant_type, quant_func, config, layer, **kwargs):
         ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
-        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name)
+        return quant_func(tensor, config, op=layer.module, op_type=layer.type, op_name=layer.name, **kwargs)

     @classmethod
     def backward(cls, ctx, grad_output):
         tensor, quant_type = ctx.saved_variables
         output = cls.quant_backward(tensor, grad_output, quant_type)
-        return output, None, None, None, None
+        return output, None, None, None, None, None

 def _check_weight(module):
     try:
         return isinstance(module.weight.data, torch.Tensor)
     except AttributeError:
         return False
\ No newline at end of file
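Putting the compressor.py changes together, the intended calling sequence from a user's side looks roughly like the sketch below. The model class, file paths, and fine-tuning step are placeholders; only `compress()`, the wrapper behavior, and the new `device` argument of `export_model` come from the diff above.

    import torch
    from nni.compression.torch import LevelPruner

    model = MyModel()  # placeholder: any torch.nn.Module with weighted layers
    config_list = [{'sparsity': 0.5, 'op_types': ['default']}]

    pruner = LevelPruner(model, config_list)
    model = pruner.compress()   # wraps each selected module in a PrunerModuleWrapper

    # ... fine-tune the wrapped model as usual; masks are applied inside the wrappers' forward ...

    pruner.export_model(model_path='pruned.pth', mask_path='mask.pth',
                        onnx_path='pruned.onnx', input_shape=(1, 3, 224, 224),
                        device=torch.device('cpu'))  # new argument: where to place the dummy ONNX input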
src/sdk/pynni/nni/compression/torch/pruners.py
@@ -27,9 +27,9 @@ class LevelPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
+        self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer

         Parameters
@@ -45,8 +45,9 @@ class LevelPruner(Pruner):
         """
         weight = layer.module.weight.data
         op_name = layer.name
-        if op_name not in self.mask_calculated_ops:
+        if_calculated = kwargs["if_calculated"]
+        if not if_calculated:
             w_abs = weight.abs()
             k = int(weight.numel() * config['sparsity'])
             if k == 0:
@@ -54,12 +55,10 @@ class LevelPruner(Pruner):
             threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
             mask_weight = torch.gt(w_abs, threshold).type_as(weight)
             mask = {'weight': mask_weight}
-            self.mask_dict.update({op_name: mask})
-            self.mask_calculated_ops.add(op_name)
+            if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable
             return mask
-        else:
-            assert op_name in self.mask_dict, "op_name not in the mask_dict"
-            mask = self.mask_dict[op_name]
-        return mask
+        return None

 class AGP_Pruner(Pruner):
@@ -84,17 +83,20 @@ class AGP_Pruner(Pruner):
         super().__init__(model, config_list)
         self.now_epoch = 0
-        self.if_init_list = {}
+        self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
-        Calculate the mask of given layer
+        Calculate the mask of given layer.
+        Scale factors with the smallest absolute value in the BN layer are masked.

         Parameters
         ----------
         layer : LayerInfo
             the layer to instrument the compression operation
         config : dict
             layer's pruning config
+        kwargs: dict
+            buffers registered in __init__ function

         Returns
         -------
         dict
@@ -102,24 +104,26 @@ class AGP_Pruner(Pruner):
         """
         weight = layer.module.weight.data
         op_name = layer.name
         start_epoch = config.get('start_epoch', 0)
         freq = config.get('frequency', 1)
-        if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
-                and (self.now_epoch - start_epoch) % freq == 0:
-            mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
-            target_sparsity = self.compute_target_sparsity(config)
-            k = int(weight.numel() * target_sparsity)
-            if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
-                return mask
-            # if we want to generate new mask, we should update weigth first
-            w_abs = weight.abs() * mask['weight']
-            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
-            new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
-            self.mask_dict.update({op_name: new_mask})
-            self.if_init_list.update({op_name: False})
-        else:
-            new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
+        if not (self.now_epoch >= start_epoch and (self.now_epoch - start_epoch) % freq == 0):
+            return None
+        mask = {'weight': kwargs['weight_mask'] if 'weight_mask' in kwargs else torch.ones(weight.shape).type_as(weight)}
+        target_sparsity = self.compute_target_sparsity(config)
+        k = int(weight.numel() * target_sparsity)
+        if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
+            return mask
+        # if we want to generate new mask, we should update weigth first
+        w_abs = weight.abs() * mask['weight']
+        threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
+        new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
+        if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable
         return new_mask

     def compute_target_sparsity(self, config):
@@ -165,9 +169,8 @@ class AGP_Pruner(Pruner):
         if epoch > 0:
             self.now_epoch = epoch
-            for k in self.if_init_list.keys():
-                self.if_init_list[k] = True
+            for wrapper in self.get_modules_wrapper():
+                wrapper.if_calculated.copy_(torch.tensor(0))  # pylint: disable=not-callable

 class SlimPruner(Pruner):
     """
@@ -187,7 +190,6 @@ class SlimPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()
         weight_list = []
         if len(config_list) > 1:
             logger.warning('Slim pruner only supports 1 configuration')
@@ -198,8 +200,9 @@ class SlimPruner(Pruner):
         all_bn_weights = torch.cat(weight_list)
         k = int(all_bn_weights.shape[0] * config['sparsity'])
         self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()
+        self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Scale factors with the smallest absolute value in the BN layer are masked.
@@ -209,6 +212,8 @@ class SlimPruner(Pruner):
             the layer to instrument the compression operation
         config : dict
             layer's pruning config
+        kwargs: dict
+            buffers registered in __init__ function

         Returns
         -------
         dict
@@ -216,27 +221,21 @@ class SlimPruner(Pruner):
         """
         weight = layer.module.weight.data
         op_name = layer.name
         op_type = layer.type
+        if_calculated = kwargs["if_calculated"]
         assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if if_calculated:
+            return None
         base_mask = torch.ones(weight.size()).type_as(weight).detach()
         mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
-        try:
-            filters = weight.size(0)
-            num_prune = int(filters * config.get('sparsity'))
-            if filters < 2 or num_prune < 1:
-                return mask
+        filters = weight.size(0)
+        num_prune = int(filters * config.get('sparsity'))
+        if filters >= 2 and num_prune >= 1:
             w_abs = weight.abs()
             mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
             mask_bias = mask_weight.clone()
             mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
-        finally:
-            self.mask_dict.update({layer.name: mask})
-            self.mask_calculated_ops.add(layer.name)
+        if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable

         return mask

 class LotteryTicketPruner(Pruner):
@@ -294,38 +293,23 @@ class LotteryTicketPruner(Pruner):
             prune_iterations = config['prune_iterations']
         return prune_iterations

-    def _print_masks(self, print_mask=False):
-        torch.set_printoptions(threshold=1000)
-        for op_name in self.mask_dict.keys():
-            mask = self.mask_dict[op_name]
-            print('op name: ', op_name)
-            if print_mask:
-                print('mask: ', mask)
-            # calculate current sparsity
-            mask_num = mask['weight'].sum().item()
-            mask_size = mask['weight'].numel()
-            print('sparsity: ', 1 - mask_num / mask_size)
-        torch.set_printoptions(profile='default')
-
     def _calc_sparsity(self, sparsity):
         keep_ratio_once = (1 - sparsity) ** (1 / self.prune_iterations)
         curr_keep_ratio = keep_ratio_once ** self.curr_prune_iteration
         return max(1 - curr_keep_ratio, 0)

-    def _calc_mask(self, weight, sparsity, op_name):
+    def _calc_mask(self, weight, sparsity, curr_w_mask):
         if self.curr_prune_iteration == 0:
             mask = torch.ones(weight.shape).type_as(weight)
         else:
             curr_sparsity = self._calc_sparsity(sparsity)
-            assert self.mask_dict.get(op_name) is not None
-            curr_mask = self.mask_dict.get(op_name)
-            w_abs = weight.abs() * curr_mask['weight']
+            w_abs = weight.abs() * curr_w_mask
             k = int(w_abs.numel() * curr_sparsity)
             threshold = torch.topk(w_abs.view(-1), k, largest=False).values.max()
             mask = torch.gt(w_abs, threshold).type_as(weight)
         return {'weight': mask}

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Generate mask for the given ``weight``.
@@ -335,15 +319,17 @@ class LotteryTicketPruner(Pruner):
             The layer to be pruned
         config : dict
             Pruning configurations for this weight
+        kwargs : dict
+            Auxiliary information

         Returns
         -------
         tensor
-            The mask for this weight
+            The mask for this weight, it is ```None``` because this pruner
+            calculates and assigns masks in ```prune_iteration_start```,
+            no need to do anything in this function.
         """
-        assert self.mask_dict.get(layer.name) is not None, 'Please call iteration_start before training'
-        mask = self.mask_dict[layer.name]
-        return mask
+        return None

     def get_prune_iterations(self):
         """
@@ -368,16 +354,26 @@ class LotteryTicketPruner(Pruner):
         self.curr_prune_iteration += 1
         assert self.curr_prune_iteration < self.prune_iterations + 1, 'Exceed the configured prune_iterations'

+        modules_wrapper = self.get_modules_wrapper()
         modules_to_compress = self.detect_modules_to_compress()
         for layer, config in modules_to_compress:
+            module_wrapper = None
+            for wrapper in modules_wrapper:
+                if wrapper.name == layer.name:
+                    module_wrapper = wrapper
+                    break
+            assert module_wrapper is not None
+
             sparsity = config.get('sparsity')
-            mask = self._calc_mask(layer.module.weight.data, sparsity, layer.name)
-            self.mask_dict.update({layer.name: mask})
-        self._print_masks()
+            mask = self._calc_mask(layer.module.weight.data, sparsity, module_wrapper.weight_mask)
+            # TODO: directly use weight_mask is not good
+            module_wrapper.weight_mask.copy_(mask['weight'])
+            # there is no mask for bias

         # reinit weights back to original after new masks are generated
         if self.reset_weights:
-            self._model.load_state_dict(self._model_state)
+            # should use this member function to reset model weights
+            self.load_model_state_dict(self._model_state)
             self._optimizer.load_state_dict(self._optimizer_state)
             if self._lr_scheduler is not None:
                 self._lr_scheduler.load_state_dict(self._scheduler_state)
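As a reference point for the AGP changes above, the per-epoch flow stays the same on the user's side; `update_epoch` now resets each wrapper's `if_calculated` buffer instead of `if_init_list`. A rough sketch follows; the model and training function are placeholders, and the exact AGP schedule keys ('initial_sparsity', 'final_sparsity', 'end_epoch') are an assumption here, since `compute_target_sparsity` is not shown in this diff.

    from nni.compression.torch import AGP_Pruner

    config_list = [{'initial_sparsity': 0.0, 'final_sparsity': 0.8,
                    'start_epoch': 0, 'end_epoch': 10, 'frequency': 1,
                    'op_types': ['default']}]
    pruner = AGP_Pruner(model, config_list)   # model: placeholder torch.nn.Module
    model = pruner.compress()

    for epoch in range(10):
        train_one_epoch(model)     # placeholder training function
        pruner.update_epoch(epoch)  # clears each wrapper's if_calculated so masks are recomputed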
src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py
@@ -27,12 +27,12 @@ class WeightRankFilterPruner(Pruner):
         """
         super().__init__(model, config_list)
-        self.mask_calculated_ops = set()  # operations whose mask has been calculated
+        self.register_buffer("if_calculated", torch.tensor(0))  # pylint: disable=not-callable

     def get_mask(self, base_mask, weight, num_prune):
         raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

-    def calc_mask(self, layer, config):
+    def calc_mask(self, layer, config, **kwargs):
         """
         Calculate the mask of given layer.
         Filters with the smallest importance criterion of the kernel weights are masked.
@@ -49,14 +49,13 @@ class WeightRankFilterPruner(Pruner):
         """
         weight = layer.module.weight.data
         op_name = layer.name
         op_type = layer.type
         assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
         assert op_type in ['Conv1d', 'Conv2d'], "only support Conv1d and Conv2d"
         assert op_type in config.get('op_types')
-        if op_name in self.mask_calculated_ops:
-            assert op_name in self.mask_dict
-            return self.mask_dict.get(op_name)
+        if_calculated = kwargs["if_calculated"]
+        if if_calculated:
+            return None
         mask_weight = torch.ones(weight.size()).type_as(weight).detach()
         if hasattr(layer.module, 'bias') and layer.module.bias is not None:
             mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
@@ -70,8 +69,7 @@ class WeightRankFilterPruner(Pruner):
                 return mask
             mask = self.get_mask(mask, weight, num_prune)
         finally:
-            self.mask_dict.update({op_name: mask})
-            self.mask_calculated_ops.add(op_name)
+            if_calculated.copy_(torch.tensor(1))  # pylint: disable=not-callable

         return mask
@@ -259,4 +257,5 @@ class FPGMPruner(WeightRankFilterPruner):
         return x.sum()

     def update_epoch(self, epoch):
-        self.mask_calculated_ops = set()
+        for wrapper in self.get_modules_wrapper():
+            wrapper.registered_buffers['if_calculated'].copy_(torch.tensor(0))  # pylint: disable=not-callable
src/sdk/pynni/nni/hyperband_advisor/hyperband_advisor.py
@@ -380,6 +380,8 @@ class Hyperband(MsgDispatcherBase):
         ValueError
             Data type not supported
         """
+        if 'value' in data:
+            data['value'] = json_tricks.loads(data['value'])
         if data['type'] == MetricType.REQUEST_PARAMETER:
             assert multi_phase_enabled()
             assert data['trial_job_id'] is not None
src/sdk/pynni/nni/medianstop_assessor/test.py
@@ -31,11 +31,11 @@ def test():
     #          [1,1,1,1,1,1,1,1,1,1],
     #          [1,1,1,1,1,1,1,1,1,1]]
-    assessor = MedianstopAssessor(FLAGS.start_step, FLAGS.optimize_mode)
-    for i in range(4):
+    assessor = MedianstopAssessor(FLAGS.optimize_mode, FLAGS.start_step)
+    for i in range(len(lcs)):
         #lc = []
         to_complete = True
-        for k in range(10):
+        for k in range(len(lcs[0])):
             #d = random.randint(i*100+0, i*100+100)
             #lc.append(d)
             ret = assessor.assess_trial(i, lcs[i][:k+1])
src/sdk/pynni/nni/msg_dispatcher.py
@@ -113,6 +113,8 @@ class MsgDispatcher(MsgDispatcherBase):
         """Import additional data for tuning
         data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
         """
+        for entry in data:
+            entry['value'] = json_tricks.loads(entry['value'])
         self.tuner.import_data(data)

     def handle_add_customized_trial(self, data):
@@ -127,6 +129,9 @@ class MsgDispatcher(MsgDispatcherBase):
         - 'value': metric value reported by nni.report_final_result()
         - 'type': report type, support {'FINAL', 'PERIODICAL'}
         """
+        # metrics value is dumped as json string in trial, so we need to decode it here
+        if 'value' in data:
+            data['value'] = json_tricks.loads(data['value'])
         if data['type'] == MetricType.FINAL:
             self._handle_final_metric_data(data)
         elif data['type'] == MetricType.PERIODICAL:
src/sdk/pynni/nni/nas/pytorch/base_mutator.py
@@ -13,7 +13,12 @@ logger = logging.getLogger(__name__)
 class BaseMutator(nn.Module):
     """
     A mutator is responsible for mutating a graph by obtaining the search space from the network and implementing
-    callbacks that are called in ``forward`` in Mutables.
+    callbacks that are called in ``forward`` in mutables.
+
+    Parameters
+    ----------
+    model : nn.Module
+        PyTorch model to apply mutator on.
     """

     def __init__(self, model):
@@ -52,9 +57,23 @@ class BaseMutator(nn.Module):
     @property
     def mutables(self):
+        """
+        A generator of all modules inheriting :class:`~nni.nas.pytorch.mutables.Mutable`.
+        Modules are yielded in the order that they are defined in ``__init__``.
+        For mutables with their keys appearing multiple times, only the first one will appear.
+        """
         return self._structured_mutables

+    @property
+    def undedup_mutables(self):
+        return self._structured_mutables.traverse(deduplicate=False)
+
     def forward(self, *inputs):
+        """
+        Warnings
+        --------
+        Don't call forward of a mutator.
+        """
         raise RuntimeError("Forward is undefined for mutators.")

     def __setattr__(self, name, value):
@@ -70,6 +89,7 @@ class BaseMutator(nn.Module):
         Parameters
         ----------
         mutable_scope : MutableScope
+            The mutable scope that is entered.
         """
         pass
@@ -80,6 +100,7 @@ class BaseMutator(nn.Module):
         Parameters
         ----------
         mutable_scope : MutableScope
+            The mutable scope that is exited.
         """
         pass
@@ -90,12 +111,14 @@ class BaseMutator(nn.Module):
         Parameters
         ----------
         mutable : LayerChoice
+            Module whose forward is called.
         inputs : list of torch.Tensor
+            The arguments of its forward function.

         Returns
         -------
         tuple of torch.Tensor and torch.Tensor
-            output tensor and mask
+            Output tensor and mask.
         """
         raise NotImplementedError
@@ -106,12 +129,14 @@ class BaseMutator(nn.Module):
         Parameters
         ----------
         mutable : InputChoice
+            Mutable that is called.
         tensor_list : list of torch.Tensor
+            The arguments mutable is called with.

         Returns
         -------
         tuple of torch.Tensor and torch.Tensor
-            output tensor and mask
+            Output tensor and mask.
         """
         raise NotImplementedError
@@ -123,5 +148,6 @@ class BaseMutator(nn.Module):
         Returns
         -------
         dict
+            Mappings from mutable keys to decisions.
         """
         raise NotImplementedError
src/sdk/pynni/nni/nas/pytorch/base_trainer.py
@@ -8,16 +8,33 @@ class BaseTrainer(ABC):

     @abstractmethod
     def train(self):
+        """
+        Override the method to train.
+        """
         raise NotImplementedError

     @abstractmethod
     def validate(self):
+        """
+        Override the method to validate.
+        """
         raise NotImplementedError

     @abstractmethod
     def export(self, file):
+        """
+        Override the method to export to file.
+
+        Parameters
+        ----------
+        file : str
+            File path to export to.
+        """
         raise NotImplementedError

     @abstractmethod
     def checkpoint(self):
+        """
+        Override to dump a checkpoint.
+        """
         raise NotImplementedError
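The newly documented abstract methods define the whole trainer interface; a skeletal subclass, purely illustrative, would look like the following (method bodies are intentionally left as stubs).

    from nni.nas.pytorch.base_trainer import BaseTrainer

    class MyTrainer(BaseTrainer):
        """Illustrative skeleton; real trainers also hold a model, mutator, optimizer, etc."""

        def train(self):
            ...  # run the search/training loop

        def validate(self):
            ...  # evaluate the current architecture

        def export(self, file):
            ...  # write the chosen architecture to `file` (e.g., JSON)

        def checkpoint(self):
            ...  # dump a checkpoint of model and mutator state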
src/sdk/pynni/nni/nas/pytorch/callbacks.py
@@ -11,6 +11,9 @@ _logger = logging.getLogger(__name__)
 class Callback:
+    """
+    Callback provides an easy way to react to events like begin/end of epochs.
+    """

     def __init__(self):
         self.model = None
@@ -18,14 +21,42 @@ class Callback:
         self.trainer = None

     def build(self, model, mutator, trainer):
+        """
+        Callback needs to be built with model, mutator, trainer, to get updates from them.
+
+        Parameters
+        ----------
+        model : nn.Module
+            Model to be trained.
+        mutator : nn.Module
+            Mutator that mutates the model.
+        trainer : BaseTrainer
+            Trainer that is to call the callback.
+        """
         self.model = model
         self.mutator = mutator
         self.trainer = trainer

     def on_epoch_begin(self, epoch):
+        """
+        Implement this to do something at the begin of epoch.
+
+        Parameters
+        ----------
+        epoch : int
+            Epoch number, starting from 0.
+        """
         pass

     def on_epoch_end(self, epoch):
+        """
+        Implement this to do something at the end of epoch.
+
+        Parameters
+        ----------
+        epoch : int
+            Epoch number, starting from 0.
+        """
         pass

     def on_batch_begin(self, epoch):
@@ -36,6 +67,14 @@ class Callback:
 class LRSchedulerCallback(Callback):
+    """
+    Calls scheduler on every epoch ends.
+
+    Parameters
+    ----------
+    scheduler : LRScheduler
+        Scheduler to be called.
+    """
     def __init__(self, scheduler, mode="epoch"):
         super().__init__()
         assert mode == "epoch"
@@ -43,28 +82,54 @@ class LRSchedulerCallback(Callback):
         self.mode = mode

     def on_epoch_end(self, epoch):
+        """
+        Call ``self.scheduler.step()`` on epoch end.
+        """
         self.scheduler.step()

 class ArchitectureCheckpoint(Callback):
+    """
+    Calls ``trainer.export()`` on every epoch ends.
+
+    Parameters
+    ----------
+    checkpoint_dir : str
+        Location to save checkpoints.
+    """
     def __init__(self, checkpoint_dir):
         super().__init__()
         self.checkpoint_dir = checkpoint_dir
         os.makedirs(self.checkpoint_dir, exist_ok=True)

     def on_epoch_end(self, epoch):
+        """
+        Dump to ``/checkpoint_dir/epoch_{number}.json`` on epoch end.
+        """
         dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))
         _logger.info("Saving architecture to %s", dest_path)
         self.trainer.export(dest_path)

 class ModelCheckpoint(Callback):
+    """
+    Calls ``trainer.export()`` on every epoch ends.
+
+    Parameters
+    ----------
+    checkpoint_dir : str
+        Location to save checkpoints.
+    """
     def __init__(self, checkpoint_dir):
         super().__init__()
         self.checkpoint_dir = checkpoint_dir
         os.makedirs(self.checkpoint_dir, exist_ok=True)

     def on_epoch_end(self, epoch):
+        """
+        Dump to ``/checkpoint_dir/epoch_{number}.pth.tar`` on every epoch end.
+        ``DataParallel`` object will have their inside modules exported.
+        """
         if isinstance(self.model, nn.DataParallel):
             state_dict = self.model.module.state_dict()
         else:
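The new docstrings above describe how callbacks hook into a trainer; a hedged sketch of wiring them up follows. The model and the commented-out trainer construction are placeholders, and any trainer that accepts a `callbacks` list should work the same way.

    import torch
    from nni.nas.pytorch.callbacks import LRSchedulerCallback, ArchitectureCheckpoint

    optimizer = torch.optim.SGD(model.parameters(), lr=0.025)  # model: placeholder nn.Module
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)

    callbacks = [
        LRSchedulerCallback(scheduler),           # steps the scheduler at every epoch end
        ArchitectureCheckpoint("./checkpoints"),  # dumps epoch_{n}.json via trainer.export()
    ]
    # trainer = DartsTrainer(model, ..., optimizer=optimizer, callbacks=callbacks)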
src/sdk/pynni/nni/nas/pytorch/cdarts/mutator.py
@@ -127,18 +127,15 @@ class RegularizedMutatorParallel(DistributedDataParallel):
 class DartsDiscreteMutator(Mutator):
+    """
+    A mutator that applies the final sampling result of a parent mutator on another model to train.
+
+    Parameters
+    ----------
+    model : nn.Module
+        The model to apply the mutator.
+    parent_mutator : Mutator
+        The mutator that provides ``sample_final`` method, that will be called to get the architecture.
+    """
     def __init__(self, model, parent_mutator):
-        """
-        Initialization.
-
-        Parameters
-        ----------
-        model : nn.Module
-            The model to apply the mutator.
-        parent_mutator : Mutator
-            The mutator that provides ``sample_final`` method, that will be called to get the architecture.
-        """
         super().__init__(model)
         self.__dict__["parent_mutator"] = parent_mutator  # avoid parameters to be included
src/sdk/pynni/nni/nas/pytorch/cdarts/trainer.py
@@ -32,73 +32,73 @@ class InteractiveKLLoss(nn.Module):
 class CdartsTrainer(object):
+    """
+    CDARTS trainer.
+
+    Parameters
+    ----------
+    model_small : nn.Module
+        PyTorch model to be trained. This is the search network of CDARTS.
+    model_large : nn.Module
+        PyTorch model to be trained. This is the evaluation network of CDARTS.
+    criterion : callable
+        Receives logits and ground truth label, return a loss tensor, e.g., ``nn.CrossEntropyLoss()``.
+    loaders : list of torch.utils.data.DataLoader
+        List of train data and valid data loaders, for training weights and architecture weights respectively.
+    samplers : list of torch.utils.data.Sampler
+        List of train data and valid data samplers. This can be PyTorch standard samplers if not distributed.
+        In distributed mode, sampler needs to have ``set_epoch`` method. Refer to data utils in CDARTS example for details.
+    logger : logging.Logger
+        The logger for logging. Will use nni logger by default (if logger is ``None``).
+    regular_coeff : float
+        The coefficient of regular loss.
+    regular_ratio : float
+        The ratio of regular loss.
+    warmup_epochs : int
+        The epochs to warmup the search network
+    fix_head : bool
+        ``True`` if fixing the paramters of auxiliary heads, else unfix the paramters of auxiliary heads.
+    epochs : int
+        Number of epochs planned for training.
+    steps_per_epoch : int
+        Steps of one epoch.
+    loss_alpha : float
+        The loss coefficient.
+    loss_T : float
+        The loss coefficient.
+    distributed : bool
+        ``True`` if using distributed training, else non-distributed training.
+    log_frequency : int
+        Step count per logging.
+    grad_clip : float
+        Gradient clipping for weights.
+    interactive_type : string
+        ``kl`` or ``smoothl1``.
+    output_path : string
+        Log storage path.
+    w_lr : float
+        Learning rate of the search network parameters.
+    w_momentum : float
+        Momentum of the search and the evaluation network.
+    w_weight_decay : float
+        The weight decay the search and the evaluation network parameters.
+    alpha_lr : float
+        Learning rate of the architecture parameters.
+    alpha_weight_decay : float
+        The weight decay the architecture parameters.
+    nasnet_lr : float
+        Learning rate of the evaluation network parameters.
+    local_rank : int
+        The number of thread.
+    share_module : bool
+        ``True`` if sharing the stem and auxiliary heads, else not sharing these modules.
+    """
     def __init__(self, model_small, model_large, criterion, loaders, samplers, logger=None,
                  regular_coeff=5, regular_ratio=0.2, warmup_epochs=2, fix_head=True,
                  epochs=32, steps_per_epoch=None, loss_alpha=2, loss_T=2, distributed=True,
                  log_frequency=10, grad_clip=5.0, interactive_type='kl', output_path='./outputs',
                  w_lr=0.2, w_momentum=0.9, w_weight_decay=3e-4, alpha_lr=0.2, alpha_weight_decay=1e-4,
                  nasnet_lr=0.2, local_rank=0, share_module=True):
-        """
-        Initialize a CdartsTrainer.
-
-        Parameters
-        ----------
-        ... (the same parameter documentation as in the class docstring above, previously kept in the ``__init__`` docstring, is removed here) ...
-        """
         if logger is None:
             logger = logging.getLogger(__name__)
         train_loader, valid_loader = loaders
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py
@@ -22,12 +22,21 @@ INPUT_CHOICE = "input_choice"
 def get_and_apply_next_architecture(model):
     """
-    Wrapper of ClassicMutator to make it more meaningful,
-    similar to ```get_next_parameter``` for HPO.
+    Wrapper of :class:`~nni.nas.pytorch.classic_nas.mutator.ClassicMutator` to make it more meaningful,
+    similar to ``get_next_parameter`` for HPO.
+
+    Tt will generate search space based on ``model``.
+    If env ``NNI_GEN_SEARCH_SPACE`` exists, this is in dry run mode for
+    generating search space for the experiment.
+    If not, there are still two mode, one is nni experiment mode where users
+    use ``nnictl`` to start an experiment. The other is standalone mode
+    where users directly run the trial command, this mode chooses the first
+    one(s) for each LayerChoice and InputChoice.

     Parameters
     ----------
-    model : pytorch model
-        user's model with search space (e.g., LayerChoice, InputChoice) embedded in it
+    model : nn.Module
+        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
     """
     ClassicMutator(model)
@@ -36,23 +45,15 @@ class ClassicMutator(Mutator):
     """
     This mutator is to apply the architecture chosen from tuner.
     It implements the forward function of LayerChoice and InputChoice,
-    to only activate the chosen ones
+    to only activate the chosen ones.
+
+    Parameters
+    ----------
+    model : nn.Module
+        User's model with search space (e.g., LayerChoice, InputChoice) embedded in it.
     """

     def __init__(self, model):
-        """
-        Generate search space based on ```model```.
-        If env ```NNI_GEN_SEARCH_SPACE``` exists, this is in dry run mode for
-        generating search space for the experiment.
-        If not, there are still two mode, one is nni experiment mode where users
-        use ```nnictl``` to start an experiment. The other is standalone mode
-        where users directly run the trial command, this mode chooses the first
-        one(s) for each LayerChoice and InputChoice.
-
-        Parameters
-        ----------
-        model : PyTorch model
-            user's model with search space (e.g., LayerChoice, InputChoice) embedded in it
-        """
         super(ClassicMutator, self).__init__(model)
         self._chosen_arch = {}
         self._search_space = self._generate_search_space()
@@ -67,6 +68,13 @@ class ClassicMutator(Mutator):
         else:
             # get chosen arch from tuner
             self._chosen_arch = nni.get_next_parameter()
+            if self._chosen_arch is None:
+                if trial_env_vars.NNI_PLATFORM == "unittest":
+                    # happens if NNI_PLATFORM is intentionally set, e.g., in UT
+                    logger.warning("`NNI_PLATFORM` is set but `param` is None. Falling back to standalone mode.")
+                    self._chosen_arch = self._standalone_generate_chosen()
+                else:
+                    raise RuntimeError("Chosen architecture is None. This may be a platform error.")
         self.reset()

     def _sample_layer_choice(self, mutable, idx, value, search_space_item):
@@ -114,9 +122,15 @@ class ClassicMutator(Mutator):
         return torch.tensor(multihot_list, dtype=torch.bool)  # pylint: disable=not-callable

     def sample_search(self):
+        """
+        See :meth:`sample_final`.
+        """
         return self.sample_final()

     def sample_final(self):
+        """
+        Convert the chosen arch and apply it on model.
+        """
         assert set(self._chosen_arch.keys()) == set(self._search_space.keys()), \
             "Unmatched keys, expected keys '{}' from search space, found '{}'.".format(self._search_space.keys(),
                                                                                        self._chosen_arch.keys())
@@ -162,6 +176,8 @@ class ClassicMutator(Mutator):
             elif val["_type"] == INPUT_CHOICE:
                 choices = val["_value"]["candidates"]
                 n_chosen = val["_value"]["n_chosen"]
+                if n_chosen is None:
+                    n_chosen = len(choices)
                 chosen_arch[key] = {"_value": choices[:n_chosen], "_idx": list(range(n_chosen))}
             else:
                 raise ValueError("Unknown key '%s' and value '%s'." % (key, val))
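The reworked docstring above is essentially the whole user-facing contract; in a trial script it is typically a one-liner. The sketch below assumes the import path from the package layout, and the model class is a placeholder.

    from nni.nas.pytorch.classic_nas import get_and_apply_next_architecture

    model = Net()  # placeholder: a model containing LayerChoice / InputChoice mutables
    # queries the tuner (or falls back to standalone mode) and fixes the chosen ops in place
    get_and_apply_next_architecture(model)
    # ... train the fixed architecture and report results with nni.report_final_result(...) ...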
src/sdk/pynni/nni/nas/pytorch/darts/mutator.py
View file @
aa316742
...
...
@@ -63,18 +63,23 @@ class DartsMutator(Mutator):
                 edges_max[mutable.key] = max_val
                 result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
         for mutable in self.mutables:
-            if isinstance(mutable, InputChoice) and mutable.n_chosen is not None:
-                weights = []
-                for src_key in mutable.choose_from:
-                    if src_key not in edges_max:
-                        _logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
-                    weights.append(edges_max.get(src_key, 0.))
-                weights = torch.tensor(weights)  # pylint: disable=not-callable
-                _, topk_edge_indices = torch.topk(weights, mutable.n_chosen)
-                selected_multihot = []
-                for i, src_key in enumerate(mutable.choose_from):
-                    if i not in topk_edge_indices and src_key in result:
-                        result[src_key] = torch.zeros_like(result[src_key])  # clear this choice to optimize calc graph
-                    selected_multihot.append(i in topk_edge_indices)
-                result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
+            if isinstance(mutable, InputChoice):
+                if mutable.n_chosen is not None:
+                    weights = []
+                    for src_key in mutable.choose_from:
+                        if src_key not in edges_max:
+                            _logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
+                        weights.append(edges_max.get(src_key, 0.))
+                    weights = torch.tensor(weights)  # pylint: disable=not-callable
+                    _, topk_edge_indices = torch.topk(weights, mutable.n_chosen)
+                    selected_multihot = []
+                    for i, src_key in enumerate(mutable.choose_from):
+                        if i not in topk_edge_indices and src_key in result:
+                            # If an edge is never selected, there is no need to calculate any op on this edge.
+                            # This is to eliminate redundant calculation.
+                            result[src_key] = torch.zeros_like(result[src_key])
+                        selected_multihot.append(i in topk_edge_indices)
+                    result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
+                else:
+                    result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())  # pylint: disable=not-callable
         return result
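The restructured branch keeps the same top-k edge selection; the standalone snippet below replays that selection on dummy edge weights so the resulting multi-hot mask is easy to inspect (edge keys and weights are invented).

import torch

edges_max = {"node2_edge0": 0.7, "node2_edge1": 0.2}           # hypothetical max alpha per input edge
choose_from = ["node2_edge0", "node2_edge1", "node2_edge2"]    # the third edge never received a weight
n_chosen = 2

weights = torch.tensor([edges_max.get(k, 0.) for k in choose_from])  # pylint: disable=not-callable
_, topk_edge_indices = torch.topk(weights, n_chosen)
selected_multihot = [i in topk_edge_indices for i in range(len(choose_from))]
print(torch.tensor(selected_multihot, dtype=torch.bool))       # tensor([ True,  True, False])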
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py
View file @
aa316742
...
...
@@ -15,46 +15,46 @@ logger = logging.getLogger(__name__)
 class DartsTrainer(Trainer):
+    """
+    DARTS trainer.
+
+    Parameters
+    ----------
+    model : nn.Module
+        PyTorch model to be trained.
+    loss : callable
+        Receives logits and ground truth label, return a loss tensor.
+    metrics : callable
+        Receives logits and ground truth label, return a dict of metrics.
+    optimizer : Optimizer
+        The optimizer used for optimizing the model.
+    num_epochs : int
+        Number of epochs planned for training.
+    dataset_train : Dataset
+        Dataset for training. Will be split for training weights and architecture weights.
+    dataset_valid : Dataset
+        Dataset for testing.
+    mutator : DartsMutator
+        Use in case of customizing your own DartsMutator. By default will instantiate a DartsMutator.
+    batch_size : int
+        Batch size.
+    workers : int
+        Workers for data loading.
+    device : torch.device
+        ``torch.device("cpu")`` or ``torch.device("cuda")``.
+    log_frequency : int
+        Step count per logging.
+    callbacks : list of Callback
+        list of callbacks to trigger at events.
+    arc_learning_rate : float
+        Learning rate of architecture parameters.
+    unrolled : float
+        ``True`` if using second order optimization, else first order optimization.
+    """
+
     def __init__(self, model, loss, metrics,
                  optimizer, num_epochs, dataset_train, dataset_valid,
                  mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                  callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
-        """
-        Initialize a DartsTrainer.
-        [... parameter documentation identical to the class docstring above, removed from __init__ ...]
-        """
         super().__init__(model, mutator if mutator is not None else DartsMutator(model),
                          loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                          batch_size, workers, device, log_frequency, callbacks)
...
...
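A minimal sketch of wiring the trainer up with the parameters documented above; the toy network, random data and accuracy helper are assumptions for illustration, and the import paths follow this repository's layout.

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from nni.nas.pytorch.darts import DartsTrainer        # import path assumed from this repo layout
from nni.nas.pytorch.mutables import LayerChoice

class TinyNet(nn.Module):
    """Toy search space: a single mutable conv followed by a classifier."""
    def __init__(self):
        super().__init__()
        self.conv = LayerChoice([nn.Conv2d(3, 8, 3, padding=1),
                                 nn.Conv2d(3, 8, 5, padding=2)])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(8, 10)

    def forward(self, x):
        return self.fc(self.pool(self.conv(x)).flatten(1))

def accuracy(logits, target):
    return {"acc": (logits.argmax(dim=-1) == target).float().mean().item()}

images, labels = torch.randn(128, 3, 32, 32), torch.randint(0, 10, (128,))
dataset = TensorDataset(images, labels)               # same toy data for train and valid

model = TinyNet()
trainer = DartsTrainer(model, loss=nn.CrossEntropyLoss(), metrics=accuracy,
                       optimizer=torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9),
                       num_epochs=2, dataset_train=dataset, dataset_valid=dataset,
                       batch_size=32, log_frequency=1, unrolled=False)
trainer.train()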
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py
View file @
aa316742
...
...
@@ -28,38 +28,38 @@ class StackedLSTMCell(nn.Module):
 class EnasMutator(Mutator):
+    """
+    A mutator that mutates the graph with RL.
+
+    Parameters
+    ----------
+    model : nn.Module
+        PyTorch model.
+    lstm_size : int
+        Controller LSTM hidden units.
+    lstm_num_layers : int
+        Number of layers for stacked LSTM.
+    tanh_constant : float
+        Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
+    cell_exit_extra_step : bool
+        If true, RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state
+        and mark it as the hidden state of this MutableScope. This is to align with the original implementation of paper.
+    skip_target : float
+        Target probability that skipconnect will appear.
+    temperature : float
+        Temperature constant that divides the logits.
+    branch_bias : float
+        Manual bias applied to make some operations more likely to be chosen.
+        Currently this is implemented with a hardcoded match rule that aligns with original repo.
+        If a mutable has a ``reduce`` in its key, all its op choices
+        that contains `conv` in their typename will receive a bias of ``+self.branch_bias`` initially; while others
+        receive a bias of ``-self.branch_bias``.
+    entropy_reduction : str
+        Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
+    """
+
     def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
                  cell_exit_extra_step=False, skip_target=0.4, temperature=None, branch_bias=0.25,
                  entropy_reduction="sum"):
-        """
-        Initialize a EnasMutator.
-        [... parameter documentation identical to the class docstring above, removed from __init__ ...]
-        """
         super().__init__(model)
         self.lstm_size = lstm_size
         self.lstm_num_layers = lstm_num_layers
...
...
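The branch_bias rule in the docstring can be spelled out with a small snippet; the key, the candidate typenames and the lower-casing detail are assumptions used only to make the rule concrete, and the real matching happens inside the mutator.

import torch

branch_bias = 0.25
mutable_key = "reduce_cell_op_1"                          # hypothetical LayerChoice key
op_typenames = ["SepConvBN", "MaxPool", "DilConvBN"]      # hypothetical candidate class names

if "reduce" in mutable_key:
    # "+branch_bias" for candidates whose typename contains "conv", "-branch_bias" otherwise.
    bias = torch.tensor([branch_bias if "conv" in name.lower() else -branch_bias
                         for name in op_typenames])
    print(bias)                                           # tensor([ 0.2500, -0.2500,  0.2500])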
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py
View file @
aa316742
...
...
@@ -16,64 +16,64 @@ logger = logging.getLogger(__name__)
 class EnasTrainer(Trainer):
+    """
+    ENAS trainer.
+
+    Parameters
+    ----------
+    model : nn.Module
+        PyTorch model to be trained.
+    loss : callable
+        Receives logits and ground truth label, return a loss tensor.
+    metrics : callable
+        Receives logits and ground truth label, return a dict of metrics.
+    reward_function : callable
+        Receives logits and ground truth label, return a tensor, which will be feeded to RL controller as reward.
+    optimizer : Optimizer
+        The optimizer used for optimizing the model.
+    num_epochs : int
+        Number of epochs planned for training.
+    dataset_train : Dataset
+        Dataset for training. Will be split for training weights and architecture weights.
+    dataset_valid : Dataset
+        Dataset for testing.
+    mutator : EnasMutator
+        Use when customizing your own mutator or a mutator with customized parameters.
+    batch_size : int
+        Batch size.
+    workers : int
+        Workers for data loading.
+    device : torch.device
+        ``torch.device("cpu")`` or ``torch.device("cuda")``.
+    log_frequency : int
+        Step count per logging.
+    callbacks : list of Callback
+        list of callbacks to trigger at events.
+    entropy_weight : float
+        Weight of sample entropy loss.
+    skip_weight : float
+        Weight of skip penalty loss.
+    baseline_decay : float
+        Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
+    child_steps : int
+        How many mini-batches for model training per epoch.
+    mutator_lr : float
+        Learning rate for RL controller.
+    mutator_steps_aggregate : int
+        Number of steps that will be aggregated into one mini-batch for RL controller.
+    mutator_steps : int
+        Number of mini-batches for each epoch of RL controller learning.
+    aux_weight : float
+        Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss.
+    test_arc_per_epoch : int
+        How many architectures are chosen for direct test after each epoch.
+    """
+
     def __init__(self, model, loss, metrics, reward_function,
                  optimizer, num_epochs, dataset_train, dataset_valid,
                  mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
                  entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500,
                  mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4,
                  test_arc_per_epoch=1):
-        """
-        Initialize an EnasTrainer.
-        [... parameter documentation identical to the class docstring above, removed from __init__ ...]
-        """
         super().__init__(model, mutator if mutator is not None else EnasMutator(model),
                          loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                          batch_size, workers, device, log_frequency, callbacks)
...
...
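The baseline_decay update documented above is a plain exponential moving average; a quick numeric check with made-up rewards:

baseline_decay = 0.999
baseline = 0.0
for reward in [0.62, 0.65, 0.70]:        # hypothetical rewards of sampled architectures
    baseline = baseline_decay * baseline + reward * (1 - baseline_decay)
print(round(baseline, 6))                # 0.001968; with decay this close to 1, the baseline moves slowly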
src/sdk/pynni/nni/nas/pytorch/fixed.py
View file @
aa316742
...
...
@@ -10,20 +10,20 @@ from nni.nas.pytorch.mutator import Mutator
 class FixedArchitecture(Mutator):
+    """
+    Fixed architecture mutator that always selects a certain graph.
+
+    Parameters
+    ----------
+    model : nn.Module
+        A mutable network.
+    fixed_arc : str or dict
+        Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
+    strict : bool
+        Force everything that appears in ``fixed_arc`` to be used at least once.
+    """
+
     def __init__(self, model, fixed_arc, strict=True):
-        """
-        Initialize a fixed architecture mutator.
-
-        Parameters
-        ----------
-        model : nn.Module
-            A mutable network.
-        fixed_arc : str or dict
-            Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
-        strict : bool
-            Force everything that appears in `fixed_arc` to be used at least once.
-        """
         super().__init__(model)
         self._fixed_arc = fixed_arc
...
...
@@ -35,9 +35,15 @@ class FixedArchitecture(Mutator):
             raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))

     def sample_search(self):
+        """
+        Always returns the fixed architecture.
+        """
         return self._fixed_arc

     def sample_final(self):
+        """
+        Always returns the fixed architecture.
+        """
         return self._fixed_arc
...
...
@@ -52,24 +58,25 @@ def _encode_tensor(data):
     return data


-def apply_fixed_architecture(model, fixed_arc_path):
+def apply_fixed_architecture(model, fixed_arc):
     """
-    Load architecture from `fixed_arc_path` and apply to model.
+    Load architecture from `fixed_arc` and apply to model.

     Parameters
     ----------
     model : torch.nn.Module
         Model with mutables.
-    fixed_arc_path : str
-        Path to the JSON that stores the architecture.
+    fixed_arc : str or dict
+        Path to the JSON that stores the architecture, or dict that stores the exported architecture.

     Returns
     -------
     FixedArchitecture
         Mutator that is responsible for fixes the graph.
     """
-    if isinstance(fixed_arc_path, str):
-        with open(fixed_arc_path, "r") as f:
+    if isinstance(fixed_arc, str):
+        with open(fixed_arc) as f:
             fixed_arc = json.load(f)
     fixed_arc = _encode_tensor(fixed_arc)
     architecture = FixedArchitecture(model, fixed_arc)
...
...
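With the new signature, either a JSON path or an already-loaded dict is accepted. A hedged usage sketch; the toy network, the key name and the one-hot value format are assumptions based on the export format, not taken from this diff.

import torch.nn as nn
from nni.nas.pytorch.fixed import apply_fixed_architecture   # module path matches this file
from nni.nas.pytorch.mutables import LayerChoice

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.op = LayerChoice([nn.Conv2d(3, 8, 3, padding=1),
                               nn.Conv2d(3, 8, 5, padding=2)], key="op_choice")

    def forward(self, x):
        return self.op(x)

model = Net()

# A hand-written architecture dict of the same shape as an exported JSON:
# a one-hot list over the candidates of "op_choice" (assumed export format).
arc = {"op_choice": [True, False]}
apply_fixed_architecture(model, arc)                         # new in this commit: a dict is accepted directly

# Equivalently, pass the path of an exported JSON file (path is an example):
# apply_fixed_architecture(model, "checkpoints/final_arc.json")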