OpenDAS / nni · Commits · 1a5c0172

Commit 1a5c0172 (unverified), authored Jan 06, 2020 by SparkSnail, committed by GitHub on Jan 06, 2020

    Merge pull request #224 from microsoft/master

    merge master

Parents: b9a7a95d, ae81ec47
Changes: 58
Showing 18 changed files with 946 additions and 69 deletions (+946 −69).
src/sdk/pynni/nni/compression/torch/pruners.py                        +383   −0
src/sdk/pynni/nni/compression/torch/quantizers.py                       +1   −1
src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py     +262   −0
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py                   +18  −12
src/sdk/pynni/nni/nas/pytorch/darts/mutator.py                         +20   −0
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py                         +36   −0
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py                          +41   −5
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py                         +113  −31
src/sdk/pynni/nni/nas/pytorch/fixed.py                                  +6  −11
src/sdk/pynni/nni/nas/pytorch/mutables.py                               +1   −1
src/sdk/pynni/nni/nas/pytorch/spos/evolution.py                         +1   −0
src/sdk/pynni/nni/nas/pytorch/spos/mutator.py                           +1   −0
src/sdk/pynni/nni/nas/pytorch/spos/trainer.py                          +31   −0
src/sdk/pynni/nni/nas/pytorch/trainer.py                                +3   −3
src/sdk/pynni/nni/nas/pytorch/utils.py                                 +25   −1
src/webui/src/components/Modal/Compare.tsx                              +2   −1
src/webui/src/components/overview/SuccessTable.tsx                      +1   −2
tools/nni_cmd/launcher.py                                               +1   −1
src/sdk/pynni/nni/compression/torch/lottery_ticket.py → src/sdk/pynni/nni/compression/torch/pruners.py
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import logging
import torch
from .compressor import Pruner

_logger = logging.getLogger(__name__)

__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'LotteryTicketPruner']

logger = logging.getLogger('torch pruner')


class LevelPruner(Pruner):
    """
    Prune to an exact pruning level specification
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            List on pruning configs
        """
        super().__init__(model, config_list)
        self.mask_calculated_ops = set()

    def calc_mask(self, layer, config):
        """
        Calculate the mask of given layer

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the compression operation
        config : dict
            layer's pruning config

        Returns
        -------
        dict
            dictionary for storing masks
        """
        weight = layer.module.weight.data
        op_name = layer.name
        if op_name not in self.mask_calculated_ops:
            w_abs = weight.abs()
            k = int(weight.numel() * config['sparsity'])
            if k == 0:
                return torch.ones(weight.shape).type_as(weight)
            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
            mask_weight = torch.gt(w_abs, threshold).type_as(weight)
            mask = {'weight': mask_weight}
            self.mask_dict.update({op_name: mask})
            self.mask_calculated_ops.add(op_name)
        else:
            assert op_name in self.mask_dict, "op_name not in the mask_dict"
            mask = self.mask_dict[op_name]
        return mask


class AGP_Pruner(Pruner):
    """
    An automated gradual pruning algorithm that prunes the smallest magnitude
    weights to achieve a preset level of network sparsity.

    Michael Zhu and Suyog Gupta, "To prune, or not to prune: exploring the
    efficacy of pruning for model compression", 2017 NIPS Workshop on Machine
    Learning of Phones and other Consumer Devices,
    https://arxiv.org/pdf/1710.01878.pdf
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            List on pruning configs
        """
        super().__init__(model, config_list)
        self.now_epoch = 0
        self.if_init_list = {}

    def calc_mask(self, layer, config):
        """
        Calculate the mask of given layer

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the compression operation
        config : dict
            layer's pruning config

        Returns
        -------
        dict
            dictionary for storing masks
        """
        weight = layer.module.weight.data
        op_name = layer.name
        start_epoch = config.get('start_epoch', 0)
        freq = config.get('frequency', 1)
        if self.now_epoch >= start_epoch and self.if_init_list.get(op_name, True) \
                and (self.now_epoch - start_epoch) % freq == 0:
            mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
            target_sparsity = self.compute_target_sparsity(config)
            k = int(weight.numel() * target_sparsity)
            if k == 0 or target_sparsity >= 1 or target_sparsity <= 0:
                return mask
            # if we want to generate new mask, we should update weight first
            w_abs = weight.abs() * mask
            threshold = torch.topk(w_abs.view(-1), k, largest=False)[0].max()
            new_mask = {'weight': torch.gt(w_abs, threshold).type_as(weight)}
            self.mask_dict.update({op_name: new_mask})
            self.if_init_list.update({op_name: False})
        else:
            new_mask = self.mask_dict.get(op_name, {'weight': torch.ones(weight.shape).type_as(weight)})
        return new_mask

    def compute_target_sparsity(self, config):
        """
        Calculate the sparsity for pruning

        Parameters
        ----------
        config : dict
            Layer's pruning config

        Returns
        -------
        float
            Target sparsity to be pruned
        """
        end_epoch = config.get('end_epoch', 1)
        start_epoch = config.get('start_epoch', 0)
        freq = config.get('frequency', 1)
        final_sparsity = config.get('final_sparsity', 0)
        initial_sparsity = config.get('initial_sparsity', 0)
        if end_epoch <= start_epoch or initial_sparsity >= final_sparsity:
            logger.warning('your end epoch <= start epoch or initial_sparsity >= final_sparsity')
            return final_sparsity

        if end_epoch <= self.now_epoch:
            return final_sparsity

        span = ((end_epoch - start_epoch - 1) // freq) * freq
        assert span > 0
        target_sparsity = (final_sparsity +
                           (initial_sparsity - final_sparsity) *
                           (1.0 - ((self.now_epoch - start_epoch) / span)) ** 3)
        return target_sparsity

    def update_epoch(self, epoch):
        """
        Update epoch

        Parameters
        ----------
        epoch : int
            current training epoch
        """
        if epoch > 0:
            self.now_epoch = epoch
            for k in self.if_init_list.keys():
                self.if_init_list[k] = True


class SlimPruner(Pruner):
    """
    A structured pruning algorithm that prunes channels by pruning the weights of BN layers.

    Zhuang Liu, Jianguo Li, Zhiqiang Shen, Gao Huang, Shoumeng Yan and Changshui Zhang
    "Learning Efficient Convolutional Networks through Network Slimming", 2017 ICCV
    https://arxiv.org/pdf/1708.06519.pdf
    """

    def __init__(self, model, config_list):
        """
        Parameters
        ----------
        config_list : list
            support key for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        """
        super().__init__(model, config_list)
        self.mask_calculated_ops = set()
        weight_list = []
        if len(config_list) > 1:
            logger.warning('Slim pruner only supports 1 configuration')
        config = config_list[0]
        for (layer, config) in self.detect_modules_to_compress():
            assert layer.type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
            weight_list.append(layer.module.weight.data.abs().clone())
        all_bn_weights = torch.cat(weight_list)
        k = int(all_bn_weights.shape[0] * config['sparsity'])
        self.global_threshold = torch.topk(all_bn_weights.view(-1), k, largest=False)[0].max()

    def calc_mask(self, layer, config):
        """
        Calculate the mask of given layer.
        Scale factors with the smallest absolute value in the BN layer are masked.

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the compression operation
        config : dict
            layer's pruning config

        Returns
        -------
        dict
            dictionary for storing masks
        """
        weight = layer.module.weight.data
        op_name = layer.name
        op_type = layer.type
        assert op_type == 'BatchNorm2d', 'SlimPruner only supports 2d batch normalization layer pruning'
        if op_name in self.mask_calculated_ops:
            assert op_name in self.mask_dict
            return self.mask_dict.get(op_name)
        base_mask = torch.ones(weight.size()).type_as(weight).detach()
        mask = {'weight': base_mask.detach(), 'bias': base_mask.clone().detach()}
        try:
            filters = weight.size(0)
            num_prune = int(filters * config.get('sparsity'))
            if filters < 2 or num_prune < 1:
                return mask
            w_abs = weight.abs()
            mask_weight = torch.gt(w_abs, self.global_threshold).type_as(weight)
            mask_bias = mask_weight.clone()
            mask = {'weight': mask_weight.detach(), 'bias': mask_bias.detach()}
        finally:
            self.mask_dict.update({layer.name: mask})
            self.mask_calculated_ops.add(layer.name)
        return mask


class LotteryTicketPruner(Pruner):
    """
    ...
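The AGP_Pruner above derives its per-epoch sparsity from the cubic schedule in compute_target_sparsity; calling update_epoch(epoch) is what advances now_epoch. A minimal standalone sketch of that schedule (plain Python re-implementation, not the NNI API; the default values are illustrative):

# Standalone sketch of the AGP cubic sparsity schedule implemented by
# AGP_Pruner.compute_target_sparsity above (illustrative re-implementation).
def agp_target_sparsity(now_epoch, initial_sparsity=0.0, final_sparsity=0.8,
                        start_epoch=0, end_epoch=10, freq=1):
    if end_epoch <= now_epoch:
        return final_sparsity
    span = ((end_epoch - start_epoch - 1) // freq) * freq
    progress = (now_epoch - start_epoch) / span
    # Sparsity rises quickly in early epochs and flattens out near final_sparsity.
    return final_sparsity + (initial_sparsity - final_sparsity) * (1.0 - progress) ** 3

for epoch in range(11):
    print(epoch, round(agp_target_sparsity(epoch), 4))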
src/sdk/pynni/nni/compression/torch/builtin_quantizers.py → src/sdk/pynni/nni/compression/torch/quantizers.py

@@ -5,7 +5,7 @@ import logging
 import torch
 from .compressor import Quantizer, QuantGrad, QuantType

-__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer']
+__all__ = ['NaiveQuantizer', 'QAT_Quantizer', 'DoReFaQuantizer', 'BNNQuantizer']

 logger = logging.getLogger(__name__)
 ...
src/sdk/pynni/nni/compression/torch/builtin_pruners.py → src/sdk/pynni/nni/compression/torch/weight_rank_filter_pruners.py

@@ -5,240 +5,9 @@ import logging
 import torch
 from .compressor import Pruner

-__all__ = ['LevelPruner', 'AGP_Pruner', 'SlimPruner', 'L1FilterPruner', 'L2FilterPruner', 'FPGMPruner',
-           'ActivationAPoZRankFilterPruner', 'ActivationMeanRankFilterPruner']
+__all__ = ['L1FilterPruner', 'L2FilterPruner', 'FPGMPruner']

-logger = logging.getLogger('torch pruner')

[The LevelPruner, AGP_Pruner and SlimPruner class definitions are deleted from this file here; the removed code is identical to the classes added to pruners.py above.]

+logger = logging.getLogger('torch weight rank filter pruners')

 class WeightRankFilterPruner(Pruner):
     """
     ...

@@ -260,8 +29,8 @@ class WeightRankFilterPruner(Pruner):
         super().__init__(model, config_list)
         self.mask_calculated_ops = set()  # operations whose mask has been calculated

-    def _get_mask(self, base_mask, weight, num_prune):
-        return {'weight': None, 'bias': None}
+    def get_mask(self, base_mask, weight, num_prune):
+        raise NotImplementedError('{} get_mask is not implemented'.format(self.__class__.__name__))

     def calc_mask(self, layer, config):
         """
         ...

@@ -299,7 +68,7 @@ class WeightRankFilterPruner(Pruner):
             num_prune = int(filters * config.get('sparsity'))
             if filters < 2 or num_prune < 1:
                 return mask
-            mask = self._get_mask(mask, weight, num_prune)
+            mask = self.get_mask(mask, weight, num_prune)
         finally:
             self.mask_dict.update({op_name: mask})
             self.mask_calculated_ops.add(op_name)
 ...

@@ -328,7 +97,7 @@ class L1FilterPruner(WeightRankFilterPruner):
         super().__init__(model, config_list)

-    def _get_mask(self, base_mask, weight, num_prune):
+    def get_mask(self, base_mask, weight, num_prune):
         """
         Calculate the mask of given layer.
         Filters with the smallest sum of its absolute kernel weights are masked.
         ...

@@ -376,7 +145,7 @@ class L2FilterPruner(WeightRankFilterPruner):
         super().__init__(model, config_list)

-    def _get_mask(self, base_mask, weight, num_prune):
+    def get_mask(self, base_mask, weight, num_prune):
         """
         Calculate the mask of given layer.
         Filters with the smallest L2 norm of the absolute kernel weights are masked.
         ...

@@ -422,7 +191,7 @@ class FPGMPruner(WeightRankFilterPruner):
         """
         super().__init__(model, config_list)

-    def _get_mask(self, base_mask, weight, num_prune):
+    def get_mask(self, base_mask, weight, num_prune):
         """
         Calculate the mask of given layer.
         Filters with the smallest sum of its absolute kernel weights are masked.
         ...

@@ -491,251 +260,3 @@ class FPGMPruner(WeightRankFilterPruner):
     def update_epoch(self, epoch):
         self.mask_calculated_ops = set()

[The following classes are removed from this file in this commit:]

class ActivationRankFilterPruner(Pruner):
    """
    A structured pruning base class that prunes the filters with the smallest
    importance criterion in convolution layers to achieve a preset level of network sparsity.

    Hengyuan Hu, Rui Peng, Yu-Wing Tai and Chi-Keung Tang,
    "Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures", ICLR 2016.
    https://arxiv.org/abs/1607.03250
    Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila and Jan Kautz,
    "Pruning Convolutional Neural Networks for Resource Efficient Inference", ICLR 2017.
    https://arxiv.org/abs/1611.06440
    """

    def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            support key for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        activation : str
            Activation function
        statistics_batch_num : int
            Num of batches for activation statistics
        """
        super().__init__(model, config_list)
        self.mask_calculated_ops = set()
        self.statistics_batch_num = statistics_batch_num
        self.collected_activation = {}
        self.hooks = {}
        assert activation in ['relu', 'relu6']
        if activation == 'relu':
            self.activation = torch.nn.functional.relu
        elif activation == 'relu6':
            self.activation = torch.nn.functional.relu6
        else:
            self.activation = None

    def compress(self):
        """
        Compress the model, register a hook for collecting activations.
        """
        modules_to_compress = self.detect_modules_to_compress()
        for layer, config in modules_to_compress:
            self._instrument_layer(layer, config)
            self.collected_activation[layer.name] = []

            def _hook(module_, input_, output, name=layer.name):
                if len(self.collected_activation[name]) < self.statistics_batch_num:
                    self.collected_activation[name].append(self.activation(output.detach().cpu()))

            layer.module.register_forward_hook(_hook)
        return self.bound_model

    def _get_mask(self, base_mask, activations, num_prune):
        return {'weight': None, 'bias': None}

    def calc_mask(self, layer, config):
        """
        Calculate the mask of given layer.
        Filters with the smallest importance criterion which is calculated from the activation are masked.

        Parameters
        ----------
        layer : LayerInfo
            the layer to instrument the compression operation
        config : dict
            layer's pruning config

        Returns
        -------
        dict
            dictionary for storing masks
        """
        weight = layer.module.weight.data
        op_name = layer.name
        op_type = layer.type
        assert 0 <= config.get('sparsity') < 1, "sparsity must in the range [0, 1)"
        assert op_type in ['Conv2d'], "only support Conv2d"
        assert op_type in config.get('op_types')
        if op_name in self.mask_calculated_ops:
            assert op_name in self.mask_dict
            return self.mask_dict.get(op_name)
        mask_weight = torch.ones(weight.size()).type_as(weight).detach()
        if hasattr(layer.module, 'bias') and layer.module.bias is not None:
            mask_bias = torch.ones(layer.module.bias.size()).type_as(layer.module.bias).detach()
        else:
            mask_bias = None
        mask = {'weight': mask_weight, 'bias': mask_bias}
        try:
            filters = weight.size(0)
            num_prune = int(filters * config.get('sparsity'))
            if filters < 2 or num_prune < 1 or len(self.collected_activation[layer.name]) < self.statistics_batch_num:
                return mask
            mask = self._get_mask(mask, self.collected_activation[layer.name], num_prune)
        finally:
            if len(self.collected_activation[layer.name]) == self.statistics_batch_num:
                self.mask_dict.update({op_name: mask})
                self.mask_calculated_ops.add(op_name)
        return mask


class ActivationAPoZRankFilterPruner(ActivationRankFilterPruner):
    """
    A structured pruning algorithm that prunes the filters with the
    smallest APoZ(average percentage of zeros) of output activations.

    Hengyuan Hu, Rui Peng, Yu-Wing Tai and Chi-Keung Tang,
    "Network Trimming: A Data-Driven Neuron Pruning Approach towards Efficient Deep Architectures", ICLR 2016.
    https://arxiv.org/abs/1607.03250
    """

    def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            support key for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        activation : str
            Activation function
        statistics_batch_num : int
            Num of batches for activation statistics
        """
        super().__init__(model, config_list, activation, statistics_batch_num)

    def _get_mask(self, base_mask, activations, num_prune):
        """
        Calculate the mask of given layer.
        Filters with the smallest APoZ(average percentage of zeros) of output activations are masked.

        Parameters
        ----------
        base_mask : dict
            The basic mask with the same shape of weight, all item in the basic mask is 1.
        activations : list
            Layer's output activations
        num_prune : int
            Num of filters to prune

        Returns
        -------
        dict
            dictionary for storing masks
        """
        apoz = self._calc_apoz(activations)
        prune_indices = torch.argsort(apoz, descending=True)[:num_prune]
        for idx in prune_indices:
            base_mask['weight'][idx] = 0.
            if base_mask['bias'] is not None:
                base_mask['bias'][idx] = 0.
        return base_mask

    def _calc_apoz(self, activations):
        """
        Calculate APoZ(average percentage of zeros) of activations.

        Parameters
        ----------
        activations : list
            Layer's output activations

        Returns
        -------
        torch.Tensor
            Filter's APoZ(average percentage of zeros) of the activations
        """
        activations = torch.cat(activations, 0)
        _eq_zero = torch.eq(activations, torch.zeros_like(activations))
        _apoz = torch.sum(_eq_zero, dim=(0, 2, 3)) / torch.numel(_eq_zero[:, 0, :, :])
        return _apoz


class ActivationMeanRankFilterPruner(ActivationRankFilterPruner):
    """
    A structured pruning algorithm that prunes the filters with the
    smallest mean value of output activations.

    Pavlo Molchanov, Stephen Tyree, Tero Karras, Timo Aila and Jan Kautz,
    "Pruning Convolutional Neural Networks for Resource Efficient Inference", ICLR 2017.
    https://arxiv.org/abs/1611.06440
    """

    def __init__(self, model, config_list, activation='relu', statistics_batch_num=1):
        """
        Parameters
        ----------
        model : torch.nn.module
            Model to be pruned
        config_list : list
            support key for each list item:
                - sparsity: percentage of convolutional filters to be pruned.
        activation : str
            Activation function
        statistics_batch_num : int
            Num of batches for activation statistics
        """
        super().__init__(model, config_list, activation, statistics_batch_num)

    def _get_mask(self, base_mask, activations, num_prune):
        """
        Calculate the mask of given layer.
        Filters with the smallest APoZ(average percentage of zeros) of output activations are masked.

        Parameters
        ----------
        base_mask : dict
            The basic mask with the same shape of weight, all item in the basic mask is 1.
        activations : list
            Layer's output activations
        num_prune : int
            Num of filters to prune

        Returns
        -------
        dict
            dictionary for storing masks
        """
        mean_activation = self._cal_mean_activation(activations)
        prune_indices = torch.argsort(mean_activation)[:num_prune]
        for idx in prune_indices:
            base_mask['weight'][idx] = 0.
            if base_mask['bias'] is not None:
                base_mask['bias'][idx] = 0.
        return base_mask

    def _cal_mean_activation(self, activations):
        """
        Calculate mean value of activations.

        Parameters
        ----------
        activations : list
            Layer's output activations

        Returns
        -------
        torch.Tensor
            Filter's mean value of the output activations
        """
        activations = torch.cat(activations, 0)
        mean_activation = torch.mean(activations, dim=(0, 2, 3))
        return mean_activation
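For reference, the APoZ statistic that the removed ActivationAPoZRankFilterPruner ranks filters by can be reproduced in isolation. A minimal sketch on a toy activation tensor (shapes and values are illustrative, not NNI code):

import torch

# Toy batch of collected activations: N=8 samples, C=4 filters, 5x5 feature maps.
activations = torch.relu(torch.randn(8, 4, 5, 5))
# APoZ per filter = fraction of zero entries over batch and spatial dimensions,
# mirroring _calc_apoz above; filters with the highest APoZ are pruned first.
eq_zero = torch.eq(activations, torch.zeros_like(activations))
apoz = torch.sum(eq_zero, dim=(0, 2, 3)).float() / torch.numel(eq_zero[:, 0, :, :])
print(apoz)  # roughly 0.5 per filter for ReLU of random Gaussian inputs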
src/sdk/pynni/nni/nas/pytorch/classic_nas/mutator.py

@@ -10,7 +10,7 @@ import torch
 import nni
 from nni.env_vars import trial_env_vars
-from nni.nas.pytorch.mutables import LayerChoice, InputChoice
+from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
 from nni.nas.pytorch.mutator import Mutator

 logger = logging.getLogger(__name__)
 ...

@@ -104,10 +104,11 @@ class ClassicMutator(Mutator):
         search_space_item : list
             The list for corresponding search space.
         """
+        candidate_repr = search_space_item["candidates"]
         multihot_list = [False] * mutable.n_candidates
         for i, v in zip(idx, value):
-            assert 0 <= i < mutable.n_candidates and search_space_item[i] == v, \
-                "Index '{}' in search space '{}' is not '{}'".format(i, search_space_item, v)
+            assert 0 <= i < mutable.n_candidates and candidate_repr[i] == v, \
+                "Index '{}' in search space '{}' is not '{}'".format(i, candidate_repr, v)
             assert not multihot_list[i], "'{}' is selected twice in '{}', which is not allowed.".format(i, idx)
             multihot_list[i] = True
         return torch.tensor(multihot_list, dtype=torch.bool)  # pylint: disable=not-callable
 ...

@@ -121,17 +122,20 @@ class ClassicMutator(Mutator):
                          self._chosen_arch.keys())
         result = dict()
         for mutable in self.mutables:
-            assert mutable.key in self._chosen_arch, "Expected '{}' in chosen arch, but not found.".format(mutable.key)
-            data = self._chosen_arch[mutable.key]
-            assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
-                "'{}' is not a valid choice.".format(data)
-            value = data["_value"]
-            idx = data["_idx"]
-            search_space_item = self._search_space[mutable.key]["_value"]
+            if isinstance(mutable, (LayerChoice, InputChoice)):
+                assert mutable.key in self._chosen_arch, \
+                    "Expected '{}' in chosen arch, but not found.".format(mutable.key)
+                data = self._chosen_arch[mutable.key]
+                assert isinstance(data, dict) and "_value" in data and "_idx" in data, \
+                    "'{}' is not a valid choice.".format(data)
             if isinstance(mutable, LayerChoice):
-                result[mutable.key] = self._sample_layer_choice(mutable, idx, value, search_space_item)
+                result[mutable.key] = self._sample_layer_choice(mutable, data["_idx"], data["_value"],
+                                                                self._search_space[mutable.key]["_value"])
             elif isinstance(mutable, InputChoice):
-                result[mutable.key] = self._sample_input_choice(mutable, idx, value, search_space_item)
+                result[mutable.key] = self._sample_input_choice(mutable, data["_idx"], data["_value"],
+                                                                self._search_space[mutable.key]["_value"])
+            elif isinstance(mutable, MutableScope):
+                logger.info("Mutable scope '%s' is skipped during parsing choices.", mutable.key)
             else:
                 raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
         return result
 ...

@@ -190,6 +194,8 @@ class ClassicMutator(Mutator):
                 search_space[key] = {"_type": INPUT_CHOICE,
                                      "_value": {"candidates": mutable.choose_from,
                                                 "n_chosen": mutable.n_chosen}}
+            elif isinstance(mutable, MutableScope):
+                logger.info("Mutable scope '%s' is skipped during generating search space.", mutable.key)
             else:
                 raise TypeError("Unsupported mutable type: '%s'." % type(mutable))
         return search_space
 ...
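The _sample_*_choice helpers above turn the tuner's chosen "_idx"/"_value" pair into a multi-hot boolean tensor over the candidates. A standalone sketch of that encoding (the indices below are hypothetical, not from a real chosen architecture):

import torch

n_candidates = 5
idx = [0, 3]                      # hypothetical "_idx" entries from the chosen architecture
multihot = [False] * n_candidates
for i in idx:
    # same bounds/duplicate checks as _sample_input_choice performs
    assert 0 <= i < n_candidates and not multihot[i]
    multihot[i] = True
print(torch.tensor(multihot, dtype=torch.bool))  # tensor([ True, False, False,  True, False])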
src/sdk/pynni/nni/nas/pytorch/darts/mutator.py

@@ -14,6 +14,26 @@ _logger = logging.getLogger(__name__)
 class DartsMutator(Mutator):
+    """
+    Connects the model in a DARTS (differentiable) way.
+
+    An extra connection is automatically inserted for each LayerChoice, when this connection is selected, there is no
+    op on this LayerChoice (namely a ``ZeroOp``), in which case, every element in the exported choice list is ``false``
+    (not chosen).
+
+    All input choice will be fully connected in the search phase. On exporting, the input choice will choose inputs based
+    on keys in ``choose_from``. If the keys were to be keys of LayerChoices, the top logit of the corresponding LayerChoice
+    will join the competition of input choice to compete against other logits. Otherwise, the logit will be assumed 0.
+
+    It's possible to cut branches by setting parameter ``choices`` in a particular position to ``-inf``. After softmax, the
+    value would be 0. Framework will ignore 0 values and not connect. Note that the gradient on the ``-inf`` location will
+    be 0. Since manipulations with ``-inf`` will be ``nan``, you need to handle the gradient update phase carefully.
+
+    Attributes
+    ----------
+    choices: ParameterDict
+        dict that maps keys of LayerChoices to weighted-connection float tensors.
+    """
+
     def __init__(self, model):
         super().__init__(model)
         self.choices = nn.ParameterDict()
 ...
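The branch-cutting behaviour described in the new DartsMutator docstring can be seen directly: a ``-inf`` entry in the architecture parameters receives exactly zero weight after softmax. A standalone illustration (toy logits, not NNI code):

import torch
import torch.nn.functional as F

alpha = torch.tensor([0.3, 1.2, float("-inf"), -0.5])  # one logit per candidate op
print(F.softmax(alpha, dim=-1))                         # the third op contributes weight 0

As the docstring warns, the gradient at the ``-inf`` position is 0 and further arithmetic on it can produce ``nan``, so the update step needs care.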
src/sdk/pynni/nni/nas/pytorch/darts/trainer.py

@@ -19,6 +19,42 @@ class DartsTrainer(Trainer):
                  optimizer, num_epochs, dataset_train, dataset_valid,
                  mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                  callbacks=None, arc_learning_rate=3.0E-4, unrolled=False):
+        """
+        Initialize a DartsTrainer.
+
+        Parameters
+        ----------
+        model : nn.Module
+            PyTorch model to be trained.
+        loss : callable
+            Receives logits and ground truth label, return a loss tensor.
+        metrics : callable
+            Receives logits and ground truth label, return a dict of metrics.
+        optimizer : Optimizer
+            The optimizer used for optimizing the model.
+        num_epochs : int
+            Number of epochs planned for training.
+        dataset_train : Dataset
+            Dataset for training. Will be split for training weights and architecture weights.
+        dataset_valid : Dataset
+            Dataset for testing.
+        mutator : DartsMutator
+            Use in case of customizing your own DartsMutator. By default will instantiate a DartsMutator.
+        batch_size : int
+            Batch size.
+        workers : int
+            Workers for data loading.
+        device : torch.device
+            ``torch.device("cpu")`` or ``torch.device("cuda")``.
+        log_frequency : int
+            Step count per logging.
+        callbacks : list of Callback
+            list of callbacks to trigger at events.
+        arc_learning_rate : float
+            Learning rate of architecture parameters.
+        unrolled : float
+            ``True`` if using second order optimization, else first order optimization.
+        """
         super().__init__(model, mutator if mutator is not None else DartsMutator(model),
                          loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                          batch_size, workers, device, log_frequency, callbacks)
 ...
src/sdk/pynni/nni/nas/pytorch/enas/mutator.py

@@ -30,11 +30,41 @@ class StackedLSTMCell(nn.Module):
 class EnasMutator(Mutator):

     def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
-                 skip_target=0.4, branch_bias=0.25):
+                 skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"):
+        """
+        Initialize a EnasMutator.
+
+        Parameters
+        ----------
+        model : nn.Module
+            PyTorch model.
+        lstm_size : int
+            Controller LSTM hidden units.
+        lstm_num_layers : int
+            Number of layers for stacked LSTM.
+        tanh_constant : float
+            Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
+        cell_exit_extra_step : bool
+            If true, RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state
+            and mark it as the hidden state of this MutableScope. This is to align with the original implementation of paper.
+        skip_target : float
+            Target probability that skipconnect will appear.
+        temperature : float
+            Temperature constant that divides the logits.
+        branch_bias : float
+            Manual bias applied to make some operations more likely to be chosen.
+            Currently this is implemented with a hardcoded match rule that aligns with original repo.
+            If a mutable has a ``reduce`` in its key, all its op choices
+            that contains `conv` in their typename will receive a bias of ``+self.branch_bias`` initially; while others
+            receive a bias of ``-self.branch_bias``.
+        entropy_reduction : str
+            Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
+        """
         super().__init__(model)
         self.lstm_size = lstm_size
         self.lstm_num_layers = lstm_num_layers
         self.tanh_constant = tanh_constant
+        self.temperature = temperature
         self.cell_exit_extra_step = cell_exit_extra_step
         self.skip_target = skip_target
         self.branch_bias = branch_bias
 ...

@@ -45,6 +75,8 @@ class EnasMutator(Mutator):
         self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
         self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
         self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)  # pylint: disable=not-callable
+        assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean."
+        self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
         self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
         self.bias_dict = nn.ParameterDict()
 ...

@@ -110,15 +142,17 @@ class EnasMutator(Mutator):
     def _sample_layer_choice(self, mutable):
         self._lstm_next_step()
         logit = self.soft(self._h[-1])
+        if self.temperature is not None:
+            logit /= self.temperature
         if self.tanh_constant is not None:
             logit = self.tanh_constant * torch.tanh(logit)
         if mutable.key in self.bias_dict:
             logit += self.bias_dict[mutable.key]
         branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
         log_prob = self.cross_entropy_loss(logit, branch_id)
-        self.sample_log_prob += torch.sum(log_prob)
+        self.sample_log_prob += self.entropy_reduction(log_prob)
         entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
-        self.sample_entropy += torch.sum(entropy)
+        self.sample_entropy += self.entropy_reduction(entropy)
         self._inputs = self.embedding(branch_id)
         return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)
 ...

@@ -133,6 +167,8 @@ class EnasMutator(Mutator):
         query = torch.cat(query, 0)
         query = torch.tanh(query + self.attn_query(self._h[-1]))
         query = self.v_attn(query)
+        if self.temperature is not None:
+            query /= self.temperature
         if self.tanh_constant is not None:
             query = self.tanh_constant * torch.tanh(query)
 ...

@@ -153,7 +189,7 @@ class EnasMutator(Mutator):
         log_prob = self.cross_entropy_loss(logit, index)
         self._inputs = anchors[index.item()]
-        self.sample_log_prob += torch.sum(log_prob)
+        self.sample_log_prob += self.entropy_reduction(log_prob)
         entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
-        self.sample_entropy += torch.sum(entropy)
+        self.sample_entropy += self.entropy_reduction(entropy)
         return skip.bool()
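The two new controller knobs act exactly where the diff shows: ``temperature`` divides the logits before sampling, and ``entropy_reduction`` reduces the per-sample log-prob/entropy terms with either ``torch.sum`` or ``torch.mean``. A standalone sketch with toy logits (not the mutator itself):

import torch
import torch.nn.functional as F

logits = torch.tensor([[1.0, 2.0, 0.5]])
temperature = 2.0
logits = logits / temperature                     # flatter distribution, more exploration
branch_id = torch.multinomial(F.softmax(logits, dim=-1), 1).view(-1)
log_prob = F.cross_entropy(logits, branch_id, reduction="none")
entropy = (log_prob * torch.exp(-log_prob)).detach()
entropy_reduction = torch.sum                     # or torch.mean, per the new argument
print(entropy_reduction(log_prob), entropy_reduction(entropy))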
src/sdk/pynni/nni/nas/pytorch/enas/trainer.py

@@ -2,11 +2,14 @@
 # Licensed under the MIT license.

 import logging
+from itertools import cycle

 import torch
+import torch.nn as nn
 import torch.optim as optim

 from nni.nas.pytorch.trainer import Trainer
-from nni.nas.pytorch.utils import AverageMeterGroup
+from nni.nas.pytorch.utils import AverageMeterGroup, to_device
 from .mutator import EnasMutator

 logger = logging.getLogger(__name__)
 ...

@@ -16,13 +19,68 @@ class EnasTrainer(Trainer):
     def __init__(self, model, loss, metrics, reward_function,
                  optimizer, num_epochs, dataset_train, dataset_valid,
                  mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
-                 entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
-                 mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4):
+                 entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, child_steps=500,
+                 mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4,
+                 test_arc_per_epoch=1):
+        """
+        Initialize an EnasTrainer.
+
+        Parameters
+        ----------
+        model : nn.Module
+            PyTorch model to be trained.
+        loss : callable
+            Receives logits and ground truth label, return a loss tensor.
+        metrics : callable
+            Receives logits and ground truth label, return a dict of metrics.
+        reward_function : callable
+            Receives logits and ground truth label, return a tensor, which will be feeded to RL controller as reward.
+        optimizer : Optimizer
+            The optimizer used for optimizing the model.
+        num_epochs : int
+            Number of epochs planned for training.
+        dataset_train : Dataset
+            Dataset for training. Will be split for training weights and architecture weights.
+        dataset_valid : Dataset
+            Dataset for testing.
+        mutator : EnasMutator
+            Use when customizing your own mutator or a mutator with customized parameters.
+        batch_size : int
+            Batch size.
+        workers : int
+            Workers for data loading.
+        device : torch.device
+            ``torch.device("cpu")`` or ``torch.device("cuda")``.
+        log_frequency : int
+            Step count per logging.
+        callbacks : list of Callback
+            list of callbacks to trigger at events.
+        entropy_weight : float
+            Weight of sample entropy loss.
+        skip_weight : float
+            Weight of skip penalty loss.
+        baseline_decay : float
+            Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
+        child_steps : int
+            How many mini-batches for model training per epoch.
+        mutator_lr : float
+            Learning rate for RL controller.
+        mutator_steps_aggregate : int
+            Number of steps that will be aggregated into one mini-batch for RL controller.
+        mutator_steps : int
+            Number of mini-batches for each epoch of RL controller learning.
+        aux_weight : float
+            Weight of auxiliary head loss. ``aux_weight * aux_loss`` will be added to total loss.
+        test_arc_per_epoch : int
+            How many architectures are chosen for direct test after each epoch.
+        """
         super().__init__(model, mutator if mutator is not None else EnasMutator(model),
                          loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
                          batch_size, workers, device, log_frequency, callbacks)
         self.reward_function = reward_function
         self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
+        self.batch_size = batch_size
+        self.workers = workers
         self.entropy_weight = entropy_weight
         self.skip_weight = skip_weight
 ...

@@ -30,32 +88,40 @@ class EnasTrainer(Trainer):
         self.baseline = 0.
         self.mutator_steps_aggregate = mutator_steps_aggregate
         self.mutator_steps = mutator_steps
+        self.child_steps = child_steps
         self.aux_weight = aux_weight
+        self.test_arc_per_epoch = test_arc_per_epoch
+
+        self.init_dataloader()
+
+    def init_dataloader(self):
         n_train = len(self.dataset_train)
         split = n_train // 10
         indices = list(range(n_train))
         train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
         valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
         self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
-                                                        batch_size=batch_size,
+                                                        batch_size=self.batch_size,
                                                         sampler=train_sampler,
-                                                        num_workers=workers)
+                                                        num_workers=self.workers)
         self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
-                                                        batch_size=batch_size,
+                                                        batch_size=self.batch_size,
                                                         sampler=valid_sampler,
-                                                        num_workers=workers)
+                                                        num_workers=self.workers)
         self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
-                                                       batch_size=batch_size,
-                                                       num_workers=workers)
+                                                       batch_size=self.batch_size,
+                                                       num_workers=self.workers)
+        self.train_loader = cycle(self.train_loader)
+        self.valid_loader = cycle(self.valid_loader)

     def train_one_epoch(self, epoch):
         # Sample model and train
         self.model.train()
         self.mutator.eval()
         meters = AverageMeterGroup()
-        for step, (x, y) in enumerate(self.train_loader):
-            x, y = x.to(self.device), y.to(self.device)
+        for step in range(1, self.child_steps + 1):
+            x, y = next(self.train_loader)
+            x, y = to_device(x, self.device), to_device(y, self.device)
             self.optimizer.zero_grad()
             with torch.no_grad():
 ...

@@ -71,55 +137,71 @@ class EnasTrainer(Trainer):
             loss = self.loss(logits, y)
             loss = loss + self.aux_weight * aux_loss
             loss.backward()
+            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
             self.optimizer.step()
             metrics["loss"] = loss.item()
             meters.update(metrics)
             if self.log_frequency is not None and step % self.log_frequency == 0:
-                logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
-                            self.num_epochs, step + 1, len(self.train_loader), meters)
+                logger.info("Model Epoch [%d/%d] Step [%d/%d] %s", epoch + 1,
+                            self.num_epochs, step, self.child_steps, meters)

         # Train sampler (mutator)
         self.model.eval()
         self.mutator.train()
         meters = AverageMeterGroup()
-        mutator_step, total_mutator_steps = 0, self.mutator_steps * self.mutator_steps_aggregate
-        while mutator_step < total_mutator_steps:
-            for step, (x, y) in enumerate(self.valid_loader):
-                x, y = x.to(self.device), y.to(self.device)
+        for mutator_step in range(1, self.mutator_steps + 1):
+            self.mutator_optim.zero_grad()
+            for step in range(1, self.mutator_steps_aggregate + 1):
+                x, y = next(self.valid_loader)
+                x, y = to_device(x, self.device), to_device(y, self.device)
                 self.mutator.reset()
                 with torch.no_grad():
                     logits = self.model(x)
                 metrics = self.metrics(logits, y)
                 reward = self.reward_function(logits, y)
-                if self.entropy_weight is not None:
-                    reward += self.entropy_weight * self.mutator.sample_entropy
+                if self.entropy_weight:
+                    reward += self.entropy_weight * self.mutator.sample_entropy.item()
                 self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
-                self.baseline = self.baseline.detach().item()
                 loss = self.mutator.sample_log_prob * (reward - self.baseline)
                 if self.skip_weight:
                     loss += self.skip_weight * self.mutator.sample_skip_penalty
                 metrics["reward"] = reward
                 metrics["loss"] = loss.item()
                 metrics["ent"] = self.mutator.sample_entropy.item()
+                metrics["log_prob"] = self.mutator.sample_log_prob.item()
                 metrics["baseline"] = self.baseline
                 metrics["skip"] = self.mutator.sample_skip_penalty
-                loss = loss / self.mutator_steps_aggregate
+                loss /= self.mutator_steps_aggregate
                 loss.backward()
                 meters.update(metrics)

-                if mutator_step % self.mutator_steps_aggregate == 0:
-                    self.mutator_optim.step()
-                    self.mutator_optim.zero_grad()
-                if self.log_frequency is not None and step % self.log_frequency == 0:
-                    logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, self.num_epochs,
-                                mutator_step // self.mutator_steps_aggregate + 1, self.mutator_steps,
-                                meters)
-                mutator_step += 1
-                if mutator_step >= total_mutator_steps:
-                    break
+                cur_step = step + (mutator_step - 1) * self.mutator_steps_aggregate
+                if self.log_frequency is not None and cur_step % self.log_frequency == 0:
+                    logger.info("RL Epoch [%d/%d] Step [%d/%d] [%d/%d] %s", epoch + 1, self.num_epochs,
+                                mutator_step, self.mutator_steps, step, self.mutator_steps_aggregate,
+                                meters)
+
+            nn.utils.clip_grad_norm_(self.mutator.parameters(), 5.)
+            self.mutator_optim.step()

     def validate_one_epoch(self, epoch):
-        pass
+        with torch.no_grad():
+            for arc_id in range(self.test_arc_per_epoch):
+                meters = AverageMeterGroup()
+                for x, y in self.test_loader:
+                    x, y = to_device(x, self.device), to_device(y, self.device)
+                    self.mutator.reset()
+                    logits = self.model(x)
+                    if isinstance(logits, tuple):
+                        logits, _ = logits
+                    metrics = self.metrics(logits, y)
+                    loss = self.loss(logits, y)
+                    metrics["loss"] = loss.item()
+                    meters.update(metrics)
+
+                logger.info("Test Epoch [%d/%d] Arc [%d/%d] Summary %s",
+                            epoch + 1, self.num_epochs, arc_id + 1, self.test_arc_per_epoch,
+                            meters.summary())
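The baseline described in the EnasTrainer docstring is a simple exponential moving average of the reward. A standalone sketch of the update used in the controller loop above (toy reward values):

# baseline_new = baseline_decay * baseline_old + reward * (1 - baseline_decay)
baseline, baseline_decay = 0.0, 0.999
for reward in [0.50, 0.55, 0.60, 0.58]:   # toy rewards from successive controller samples
    baseline = baseline * baseline_decay + reward * (1 - baseline_decay)
advantage = 0.58 - baseline                # the REINFORCE loss scales sample_log_prob by (reward - baseline)
print(baseline, advantage)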
src/sdk/pynni/nni/nas/pytorch/fixed.py
View file @
1a5c0172
...
@@ -41,18 +41,18 @@ class FixedArchitecture(Mutator):
        return self._fixed_arc


-def _encode_tensor(data, device):
+def _encode_tensor(data):
    if isinstance(data, list):
        if all(map(lambda o: isinstance(o, bool), data)):
-           return torch.tensor(data, dtype=torch.bool, device=device)  # pylint: disable=not-callable
+           return torch.tensor(data, dtype=torch.bool)  # pylint: disable=not-callable
        else:
-           return torch.tensor(data, dtype=torch.float, device=device)  # pylint: disable=not-callable
+           return torch.tensor(data, dtype=torch.float)  # pylint: disable=not-callable
    if isinstance(data, dict):
-       return {k: _encode_tensor(v, device) for k, v in data.items()}
+       return {k: _encode_tensor(v) for k, v in data.items()}
    return data


-def apply_fixed_architecture(model, fixed_arc_path, device=None):
+def apply_fixed_architecture(model, fixed_arc_path):
    """
    Load architecture from `fixed_arc_path` and apply to model.
...
@@ -62,21 +62,16 @@ def apply_fixed_architecture(model, fixed_arc_path, device=None):
        Model with mutables.
    fixed_arc_path : str
        Path to the JSON that stores the architecture.
-   device : torch.device
-       Architecture weights will be transfered to `device`.

    Returns
    -------
    FixedArchitecture
    """
-   if device is None:
-       device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(fixed_arc_path, str):
        with open(fixed_arc_path, "r") as f:
            fixed_arc = json.load(f)
-   fixed_arc = _encode_tensor(fixed_arc, device)
+   fixed_arc = _encode_tensor(fixed_arc)
    architecture = FixedArchitecture(model, fixed_arc)
-   architecture.to(device)
    architecture.reset()
    return architecture
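`_encode_tensor` turns the JSON-decoded architecture into tensors: lists of booleans (one-hot choices) become `torch.bool` tensors, other lists become `torch.float`, and dicts are encoded recursively; after this change the device transfer is no longer done here. A standalone sketch of the same encoding rule (not the module's own function), handy for checking what a fixed-architecture JSON turns into; the choice keys below are purely illustrative:

    import torch

    def encode_choice(data):
        # Bool lists -> bool tensors, numeric lists -> float tensors, dicts recurse.
        if isinstance(data, list):
            if all(isinstance(o, bool) for o in data):
                return torch.tensor(data, dtype=torch.bool)
            return torch.tensor(data, dtype=torch.float)
        if isinstance(data, dict):
            return {k: encode_choice(v) for k, v in data.items()}
        return data

    fixed_arc = {"conv_choice": [True, False, False], "skip_weights": [0.2, 0.8]}
    print(encode_choice(fixed_arc))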
src/sdk/pynni/nni/nas/pytorch/mutables.py
View file @
1a5c0172
...
@@ -159,7 +159,7 @@ class InputChoice(Mutable):
            "than number of candidates."
        self.n_candidates = n_candidates
-       self.choose_from = choose_from
+       self.choose_from = choose_from.copy()
        self.n_chosen = n_chosen
        self.reduction = reduction
        self.return_mask = return_mask
...
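Copying `choose_from` keeps the mutable's key list independent of the list the caller passed in. A minimal illustration of the aliasing problem the `.copy()` avoids (generic Python, not the InputChoice implementation itself):

    class Chooser:
        def __init__(self, choose_from):
            # Without .copy(), self.choose_from would alias the caller's list.
            self.choose_from = choose_from.copy()

    keys = ["conv3x3", "conv5x5"]
    c = Chooser(keys)
    keys.append("maxpool")   # caller mutates its list afterwards
    print(c.choose_from)     # ['conv3x3', 'conv5x5'] -- unaffected thanks to the copy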
src/sdk/pynni/nni/nas/pytorch/spos/evolution.py
View file @
1a5c0172
...
@@ -211,6 +211,7 @@ class SPOSEvolution(Tuner):
        Parameters
        ----------
        result : dict
+           Chosen architectures to be exported.
        """
        os.makedirs("checkpoints", exist_ok=True)
        for i, cand in enumerate(result):
...
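The export routine writes each chosen architecture into a `checkpoints` directory. A hedged sketch of that pattern with hypothetical candidate data (the real method's file naming and JSON layout may differ):

    import json
    import os

    def export_candidates(result):
        # Write each candidate architecture to checkpoints/arc_<i>.json.
        os.makedirs("checkpoints", exist_ok=True)
        for i, cand in enumerate(result):
            with open(os.path.join("checkpoints", "arc_%d.json" % i), "w") as f:
                json.dump(cand, f, indent=2)

    export_candidates([{"block_0": [True, False]}, {"block_0": [False, True]}])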
src/sdk/pynni/nni/nas/pytorch/spos/mutator.py
View file @
1a5c0172
...
@@ -17,6 +17,7 @@ class SPOSSupernetTrainingMutator(RandomMutator):
        Parameters
        ----------
        model : nn.Module
+           PyTorch model.
        flops_func : callable
            Callable that takes a candidate from `sample_search` and returns its candidate. When `flops_func`
            is None, functions related to flops will be deactivated.
...
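A `flops_func` receives a sampled candidate and returns its FLOPs so the mutator can reject samples that fall outside a budget. A hypothetical callable; the lookup-table numbers and the candidate format are illustrative only:

    # Hypothetical per-op FLOPs lookup table (in millions).
    FLOPS_TABLE = {"conv3x3": 12.0, "conv5x5": 28.0, "identity": 0.0}

    def flops_func(candidate):
        # Sum the table entries for every op chosen in the sampled candidate.
        return sum(FLOPS_TABLE.get(choice, 0.0) for choice in candidate.values())

    print(flops_func({"layer1": "conv3x3", "layer2": "conv5x5"}))  # 40.0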
src/sdk/pynni/nni/nas/pytorch/spos/trainer.py
View file @
1a5c0172
...
@@ -21,6 +21,37 @@ class SPOSSupernetTrainer(Trainer):
                 optimizer, num_epochs, train_loader, valid_loader,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                 callbacks=None):
+        """
+        Parameters
+        ----------
+        model : nn.Module
+            Model with mutables.
+        mutator : Mutator
+            A mutator object that has been initialized with the model.
+        loss : callable
+            Called with logits and targets. Returns a loss tensor.
+        metrics : callable
+            Returns a dict that maps metrics keys to metrics data.
+        optimizer : Optimizer
+            Optimizer that optimizes the model.
+        num_epochs : int
+            Number of epochs of training.
+        train_loader : iterable
+            Data loader of training. Raise ``StopIteration`` when one epoch is exhausted.
+        dataset_valid : iterable
+            Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted.
+        batch_size : int
+            Batch size.
+        workers: int
+            Number of threads for data preprocessing. Not used for this trainer. Maybe removed in future.
+        device : torch.device
+            Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, trainer will
+            automatic detects GPU and selects GPU first.
+        log_frequency : int
+            Number of mini-batches to log metrics.
+        callbacks : list of Callback
+            Callbacks to plug into the trainer. See Callbacks.
+        """
        assert torch.cuda.is_available()
        super().__init__(model, mutator if mutator is not None else SPOSSupernetTrainingMutator(model),
                         loss, metrics, optimizer, num_epochs, None, None,
...
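Putting the documented parameters together, a hedged instantiation sketch; the placeholder model and data are not a real SPOS supernet, and the keyword names should be checked against the signature shown above:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset

    model = nn.Sequential(nn.Flatten(), nn.Linear(32, 10))        # placeholder supernet
    data = TensorDataset(torch.randn(256, 32), torch.randint(0, 10, (256,)))
    train_loader = DataLoader(data, batch_size=64)
    valid_loader = DataLoader(data, batch_size=64)

    def metrics(logits, target):
        return {"acc1": (logits.argmax(1) == target).float().mean().item()}

    trainer_kwargs = dict(
        model=model,
        loss=nn.CrossEntropyLoss(),
        metrics=metrics,
        optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
        num_epochs=1,
        train_loader=train_loader,
        valid_loader=valid_loader,
        log_frequency=10,
    )
    # trainer = SPOSSupernetTrainer(**trainer_kwargs)   # needs a CUDA device per the assert above
    # trainer.train()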
src/sdk/pynni/nni/nas/pytorch/trainer.py
View file @
1a5c0172
...
@@ -52,7 +52,7 @@ class Trainer(BaseTrainer):
        workers : int
            Number of workers used in data preprocessing.
        device : torch.device
-           Device object. Either `torch.device("cuda")` or torch.device("cpu")`. When `None`, trainer will
+           Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, trainer will
            automatic detects GPU and selects GPU first.
        log_frequency : int
            Number of mini-batches to log metrics.
...
@@ -96,12 +96,12 @@ class Trainer(BaseTrainer):
                callback.on_epoch_begin(epoch)

            # training
-           _logger.info("Epoch %d Training", epoch)
+           _logger.info("Epoch %d Training", epoch + 1)
            self.train_one_epoch(epoch)

            if validate:
                # validation
-               _logger.info("Epoch %d Validating", epoch)
+               _logger.info("Epoch %d Validating", epoch + 1)
                self.validate_one_epoch(epoch)

            for callback in self.callbacks:
...
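The `device` behaviour documented here is the usual auto-detection idiom (the same one removed from fixed.py above); this is all that "automatically detects GPU and selects GPU first" amounts to:

    import torch

    # Fall back to CPU only when no CUDA device is visible.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(device)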
src/sdk/pynni/nni/nas/pytorch/utils.py
View file @
1a5c0172
...
@@ -4,6 +4,8 @@
import logging
from collections import OrderedDict

+import torch
+
_counter = 0
_logger = logging.getLogger(__name__)
...
@@ -15,7 +17,22 @@ def global_mutable_counting():
    return _counter


+def to_device(obj, device):
+    if torch.is_tensor(obj):
+        return obj.to(device)
+    if isinstance(obj, tuple):
+        return tuple(to_device(t, device) for t in obj)
+    if isinstance(obj, list):
+        return [to_device(t, device) for t in obj]
+    if isinstance(obj, dict):
+        return {k: to_device(v, device) for k, v in obj.items()}
+    if isinstance(obj, (int, float, str)):
+        return obj
+    raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))
+
+
class AverageMeterGroup:
+    """Average meter group for multiple average meters"""

    def __init__(self):
        self.meters = OrderedDict()
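`to_device` recursively moves tensors nested inside tuples, lists and dicts while letting plain ints, floats and strings pass through; that is what allows the trainers above to call it on structured batches. A small usage sketch, assuming the package version from this commit is installed:

    import torch
    from nni.nas.pytorch.utils import to_device   # added in this commit

    batch = {
        "image": torch.randn(4, 3, 8, 8),
        "boxes": [torch.zeros(2, 4), torch.zeros(3, 4)],
        "id": "sample-0",                          # non-tensor values pass through unchanged
    }
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch = to_device(batch, device)
    print(batch["image"].device, batch["id"])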
...
@@ -33,7 +50,10 @@ class AverageMeterGroup:
        return self.meters[item]

    def __str__(self):
-       return " ".join(str(v) for _, v in self.meters.items())
+       return " ".join(str(v) for v in self.meters.values())
+
+   def summary(self):
+       return " ".join(v.summary() for v in self.meters.values())


class AverageMeter:
...
@@ -72,6 +92,10 @@ class AverageMeter:
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(**self.__dict__)

+    def summary(self):
+        fmtstr = '{name}: {avg' + self.fmt + '}'
+        return fmtstr.format(**self.__dict__)
+

class StructuredMutableTreeNode:
    """
...
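The new `summary()` methods render running averages for end-of-epoch logs using the format strings above. A minimal stand-in meter that shows how those strings render; its update logic is an assumption for illustration, not the class in this file:

    class MiniMeter:
        def __init__(self, name, fmt=":.3f"):
            self.name, self.fmt = name, fmt
            self.val = self.avg = self.sum = self.count = 0

        def update(self, val, n=1):
            self.val = val
            self.sum += val * n
            self.count += n
            self.avg = self.sum / self.count

        def __str__(self):
            return ('{name} {val' + self.fmt + '} ({avg' + self.fmt + '})').format(**self.__dict__)

        def summary(self):
            return ('{name}: {avg' + self.fmt + '}').format(**self.__dict__)

    m = MiniMeter("acc1")
    for v in (0.70, 0.74, 0.78):
        m.update(v)
    print(str(m))       # acc1 0.780 (0.740)
    print(m.summary())  # acc1: 0.740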
src/webui/src/components/Modal/Compare.tsx
View file @
1a5c0172
...
@@ -91,7 +91,8 @@ class Compare extends React.Component<CompareProps, {}> {
                },
                yAxis: {
                    type: 'value',
-                   name: 'Metric'
+                   name: 'Metric',
+                   scale: true
                },
                series: trialIntermediate
            };
...
src/webui/src/components/overview/SuccessTable.tsx
View file @
1a5c0172
...
@@ -28,12 +28,11 @@ class SuccessTable extends React.Component<SuccessTableProps, {}> {
            {
                title: 'Trial No.',
                dataIndex: 'sequenceId',
-               width: 140,
                className: 'tableHead'
            }, {
                title: 'ID',
                dataIndex: 'id',
-               width: 60,
+               width: 80,
                className: 'tableHead leftTitle',
                render: (text: string, record: TableRecord): React.ReactNode => {
                    return (
...
tools/nni_cmd/launcher.py
View file @
1a5c0172
...
@@ -517,7 +517,7 @@ def manage_stopped_experiment(args, mode):
    experiment_id = None
    #find the latest stopped experiment
    if not args.id:
-       print_error('Please set experiment id!\nYou could use \'nnictl {0} {id}\' to {0} a stopped experiment!\n' \
+       print_error('Please set experiment id!\nYou could use \'nnictl {0} id\' to {0} a stopped experiment!\n' \
            'You could use \'nnictl experiment list --all\' to show all experiments!'.format(mode))
        exit(1)
    else:
...