OpenDAS / nni / Commits / 3ec26b40

Unverified commit 3ec26b40, authored Dec 11, 2020 by liuzhe-lz, committed by GitHub on Dec 11, 2020.

Merge master into dev-retiarii (#3178)
parent d165905d

Showing 20 changed files with 1112 additions and 383 deletions.
nni/compression/pytorch/compressor.py                 +62  −11
nni/compression/pytorch/pruning/__init__.py           +4   −0
nni/compression/pytorch/pruning/apply_compression.py  +29  −0
nni/compression/pytorch/speedup/compress_modules.py   +118 −11
nni/compression/pytorch/speedup/infer_shape.py        +106 −7
nni/compression/pytorch/utils/mask_conflict.py        +49  −22
nni/compression/pytorch/utils/shape_dependency.py     +28  −14
nni/experiment/__init__.py                            +1   −1
nni/experiment/config/__init__.py                     +4   −2
nni/experiment/config/base.py                         +145 −107
nni/experiment/config/common.py                       +145 −0
nni/experiment/config/convert.py                      +228 −0
nni/experiment/config/local.py                        +16  −30
nni/experiment/config/util.py                         +54  −0
nni/experiment/experiment.py                          +52  −117
nni/experiment/launcher.py                            +26  −12
nni/experiment/nni_client.py                          +0   −33
nni/experiment/pipe.py                                +10  −10
nni/runtime/log.py                                    +31  −2
nni/runtime/msg_dispatcher_base.py                    +4   −4
nni/compression/pytorch/compressor.py

@@ -580,17 +580,55 @@ class QuantType:
     """
     Enum class for quantization type.
     """
-    QUANT_INPUT = 0
-    QUANT_WEIGHT = 1
-    QUANT_OUTPUT = 2
+    QUANT_INPUT = 'input'
+    QUANT_WEIGHT = 'weight'
+    QUANT_OUTPUT = 'output'

 class QuantGrad(torch.autograd.Function):
     """
     Base class for overriding backward function of quantization operation.
     """
+    @classmethod
+    def _quantize(cls, x, scale, zero_point):
+        """
+        Reference function for quantizing x -- non-clamped.
+        Parameters
+        ----------
+        x : Tensor
+            tensor to be quantized
+        scale : Tensor
+            scale for quantizing x
+        zero_point : Tensor
+            zero_point for quantizing x
+        Returns
+        -------
+        tensor
+            quantized x without clamped
+        """
+        return ((x / scale) + zero_point).round()
+
+    @classmethod
+    def get_bits_length(cls, config, quant_type):
+        """
+        Get bit for quantize config
+        Parameters
+        ----------
+        config : Dict
+            the configuration for quantization
+        quant_type : str
+            quant type
+        Returns
+        -------
+        int
+            n-bits for quantization configuration
+        """
+        if isinstance(config["quant_bits"], int):
+            return config["quant_bits"]
+        else:
+            return config["quant_bits"].get(quant_type)
+
     @staticmethod
-    def quant_backward(tensor, grad_output, quant_type):
+    def quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
         """
         This method should be overrided by subclass to provide customized backward function,
         default implementation is Straight-Through Estimator
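The two new classmethods are small enough to check by hand. A quick sketch of what they compute (the values here are illustrative, not from the commit):

import torch

# _quantize: affine quantization without clamping
x = torch.tensor([-1.0, 0.0, 1.0])
scale, zero_point = torch.tensor(0.05), torch.tensor(128.0)
print(((x / scale) + zero_point).round())   # tensor([108., 128., 148.])

# get_bits_length: the config's quant_bits may be a single int or a
# per-type dict keyed by the new string-valued QuantType constants
config = {'quant_bits': {'input': 8, 'weight': 4, 'output': 8}}
print(config['quant_bits'].get('weight'))   # 4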
@@ -600,32 +638,45 @@ class QuantGrad(torch.autograd.Function):
             input of quantization operation
         grad_output : Tensor
             gradient of the output of quantization operation
-        quant_type : QuantType
+        scale : Tensor
             the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
             you can define different behavior for different types.
+        zero_point : Tensor
+            zero_point for quantizing tensor
+        qmin : Tensor
+            quant_min for quantizing tensor
+        qmax : Tensor
+            quant_max for quantizng tensor
         Returns
         -------
         tensor
             gradient of the input of quantization operation
         """
+        tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
+        mask = (tensor_q < qmin) | (tensor_q > qmax)
+        grad_output[mask] = 0
         return grad_output

     @staticmethod
     def forward(ctx, tensor, quant_type, wrapper, **kwargs):
-        ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
         if quant_type == QuantType.QUANT_INPUT:
-            return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
+            output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
         elif quant_type == QuantType.QUANT_WEIGHT:
-            return wrapper.quantizer.quantize_weight(wrapper, **kwargs)
+            output = wrapper.quantizer.quantize_weight(wrapper, **kwargs)
         elif quant_type == QuantType.QUANT_OUTPUT:
-            return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
+            output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
         else:
             raise ValueError("unrecognized QuantType.")
+        bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
+        qmin, qmax = torch.Tensor([0], device=tensor.device), torch.Tensor([(1 << bits) - 1], device=tensor.device)
+        ctx.save_for_backward(tensor, wrapper.module.scale, wrapper.module.zero_point, qmin, qmax)
+        return output

     @classmethod
     def backward(cls, ctx, grad_output):
-        tensor, quant_type = ctx.saved_variables
-        output = cls.quant_backward(tensor, grad_output, quant_type)
+        tensor, scale, zero_point, qmin, qmax = ctx.saved_variables
+        output = cls.quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax)
         return output, None, None, None

 def _check_weight(module):
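Taken together, forward now saves scale, zero_point, and the derived [qmin, qmax] range instead of the quant type, and the default quant_backward becomes a straight-through estimator that zeroes gradients for values quantized out of range. A minimal standalone sketch of that behavior (the function name and inputs are illustrative, not part of the commit):

import torch

def ste_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
    # Re-quantize the saved input the same way QuantGrad._quantize does.
    tensor_q = ((tensor / scale) + zero_point).round()
    # Straight-through estimator: pass gradients through unchanged, except
    # where the quantized value fell outside the representable range.
    mask = (tensor_q < qmin) | (tensor_q > qmax)
    grad_output = grad_output.clone()  # avoid mutating the caller's tensor
    grad_output[mask] = 0
    return grad_output

x = torch.tensor([-10.0, 0.5, 10.0])
g = torch.ones(3)
# 8-bit range [0, 255]: the two out-of-range endpoints get zero gradient
print(ste_backward(x, g, torch.tensor(0.05), torch.tensor(128.0),
                   torch.tensor(0.0), torch.tensor(255.0)))  # tensor([0., 1., 0.])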
nni/compression/pytorch/pruning/__init__.py (new file, mode 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .apply_compression import apply_compression_results
nni/compression/pytorch/pruning/apply_compression.py (new file, mode 100644)

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import torch

logger = logging.getLogger('torch apply compression')


def apply_compression_results(model, masks_file, map_location=None):
    """
    Apply the masks from ```masks_file``` to the model
    Note: this API is for inference, because it simply multiplies weights with
    corresponding masks when this API is called.

    Parameters
    ----------
    model : torch.nn.Module
        The model to be compressed
    masks_file : str
        The path of the mask file
    map_location : str
        the device on which masks are placed, same to map_location in ```torch.load```
    """
    masks = torch.load(masks_file, map_location)
    for name, module in model.named_modules():
        if name in masks:
            module.weight.data = module.weight.data.mul_(masks[name]['weight'])
            if hasattr(module, 'bias') and module.bias is not None and 'bias' in masks[name]:
                module.bias.data = module.bias.data.mul_(masks[name]['bias'])
\ No newline at end of file
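Since this helper only multiplies weights by their masks in place, it is meant for inference on an already-pruned checkpoint. A typical call might look like this (the model and file paths are placeholders; the mask file would come from a pruner's export step):

import torch
from torchvision.models import resnet18
from nni.compression.pytorch.pruning import apply_compression_results

model = resnet18()
model.load_state_dict(torch.load('pruned_weights.pth'))  # placeholder path
apply_compression_results(model, 'mask.pth', map_location='cpu')
model.eval()  # weights are now zeroed wherever the masks are zero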
nni/compression/pytorch/speedup/compress_modules.py

@@ -10,6 +10,7 @@ _logger = logging.getLogger(__name__)
 replace_module = {
     'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
     'Conv2d': lambda module, mask: replace_conv2d(module, mask),
+    'ConvTranspose2d': lambda module, mask: replace_convtranspose2d(module, mask),
     'MaxPool2d': lambda module, mask: no_replace(module, mask),
     'AvgPool2d': lambda module, mask: no_replace(module, mask),
     'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
@@ -22,6 +23,7 @@ replace_module = {
     'Dropout3d': lambda module, mask: no_replace(module, mask)
 }

 def no_replace(module, mask):
     """
     No need to replace
@@ -29,6 +31,7 @@ def no_replace(module, mask):
     _logger.debug("no need to replace")
     return module

 def replace_linear(linear, mask):
     """
     Parameters
@@ -54,11 +57,13 @@ def replace_linear(linear, mask):
                                  out_features=linear.out_features,
                                  bias=linear.bias is not None)
     new_linear.to(linear.weight.device)
     new_linear.weight.data = torch.index_select(linear.weight.data, -1, index.to(linear.weight.device))
     if linear.bias is not None:
         new_linear.bias.data.copy_(linear.bias.data)
     return new_linear

 def replace_batchnorm2d(norm, mask):
     """
     Parameters
@@ -87,10 +92,13 @@ def replace_batchnorm2d(norm, mask):
     new_norm.weight.data = torch.index_select(norm.weight.data, 0, index)
     new_norm.bias.data = torch.index_select(norm.bias.data, 0, index)
     if norm.track_running_stats:
-        new_norm.running_mean.data = torch.index_select(norm.running_mean.data, 0, index)
+        new_norm.running_mean.data = torch.index_select(
+            norm.running_mean.data, 0, index)
         new_norm.running_var.data = torch.index_select(norm.running_var.data, 0, index)
     return new_norm

 def replace_conv2d(conv, mask):
     """
     Parameters
@@ -121,7 +129,8 @@ def replace_conv2d(conv, mask):
         # remove groups for depthwise layers
         assert in_channels == out_channels
         groups = in_channels
     _logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d", mask.module_name, in_channels, out_channels)
     new_conv = torch.nn.Conv2d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=conv.kernel_size,
@@ -136,9 +145,11 @@ def replace_conv2d(conv, mask):
     tmp_weight_data = tmp_bias_data = None
     if mask.output_mask is not None:
         tmp_weight_data = torch.index_select(conv.weight.data, 0, out_channels_index)
         if conv.bias is not None:
             tmp_bias_data = torch.index_select(conv.bias.data, 0, out_channels_index)
     else:
         tmp_weight_data = conv.weight.data
     # For the convolutional layers that have more than one group
@@ -152,24 +163,120 @@ def replace_conv2d(conv, mask):
         for groupid in range(conv.groups):
             start = groupid * input_step
             end = (groupid + 1) * input_step
             current_input_index = list(filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
             if not current_input_index:
                 # there is no kept channel in current group
-                continue
+                # TODO bug here, the groups is directly get from conv.groups, if the whole group is removed,
+                # then the number of groups in the new_conv also need to change
+                raise Exception(" Donnot support removing the whole group filter except in the depth-wise conv temporarily")
             # shift the global index into the group index
             current_input_index = [x - start for x in current_input_index]
             # if the groups is larger than 1, the input channels of each
             # group should be pruned evenly.
             assert len(current_input_index) == in_channels_group, \
                 'Input channels of each group are not pruned evenly'
             current_input_index = torch.tensor(current_input_index).to(tmp_weight_data.device)  # pylint: disable=not-callable
             f_start = groupid * filter_step
             f_end = (groupid + 1) * filter_step
             new_conv.weight.data[f_start:f_end] = torch.index_select(tmp_weight_data[f_start:f_end], 1, current_input_index)
     else:
         new_conv.weight.data.copy_(tmp_weight_data)
     if conv.bias is not None:
         new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data)
     return new_conv

+def replace_convtranspose2d(convtrans, mask):
+    """
+    We need anothor replace function for
+    convtranspose2d, because the layout of
+    the weight is different from traditional
+    conv layers. The layout of the weight is [N_in, N_out, ksize_1, ksize_2]
+    Parameters
+    ----------
+    convtrans : torch.nn.ConvTranspose2d
+        The conv2d module to be replaced
+    mask : ModuleMasks
+        The masks of this module
+    Returns
+    -------
+    torch.nn.ConvTranspose2d
+        The new conv2d module
+    """
+    assert isinstance(mask, ModuleMasks)
+    assert isinstance(convtrans, torch.nn.ConvTranspose2d)
+    if mask.input_mask is None:
+        in_channels = convtrans.in_channels
+    else:
+        in_channels_index = mask.input_mask.mask_index[1]
+        in_channels = in_channels_index.size(0)
+    if mask.output_mask is None:
+        out_channels = convtrans.out_channels
+    else:
+        out_channels_index = mask.output_mask.mask_index[1]
+        out_channels = out_channels_index.size(0)
+    groups = convtrans.groups
+    # check if can remove the whole group of filters
+    if convtrans.in_channels == convtrans.out_channels == convtrans.groups:
+        # remove groups for depthwise layers
+        # this needs the group dependency to be fixed before the speedup
+        assert in_channels == out_channels
+        groups = in_channels
+    _logger.debug('Replace convtranspose2d %s with in_channels:%d out_channels:%d', mask.module_name, in_channels, out_channels)
+    new_convtrans = torch.nn.ConvTranspose2d(in_channels=in_channels,
+                                             out_channels=out_channels,
+                                             kernel_size=convtrans.kernel_size,
+                                             stride=convtrans.stride,
+                                             padding=convtrans.padding,
+                                             dilation=convtrans.dilation,
+                                             groups=groups,
+                                             bias=convtrans.bias is not None,
+                                             padding_mode=convtrans.padding_mode)
+    new_convtrans.to(convtrans.weight.device)
+    tmp_weight_data = None
+    if mask.input_mask is not None:
+        # in convtranspose2d we need to select the input channel first
+        tmp_weight_data = torch.index_select(convtrans.weight.data, 0, in_channels_index)
+    else:
+        tmp_weight_data = convtrans.weight.data
+    # we need to handle the output channel group by group like the conv layer
+    out_step = int(convtrans.out_channels / convtrans.groups)
+    out_channel_group = int(out_channels / groups)
+    new_in_per_group = int(in_channels / groups)
+    if mask.output_mask is not None and not (in_channels == out_channels == groups):
+        for groupid in range(convtrans.groups):
+            start = groupid * out_step
+            end = (groupid + 1) * out_step
+            current_output_index = list(filter(lambda x: start <= x and x < end, out_channels_index.tolist()))
+            # we need to shift the index into the group-wise
+            current_output_index = [x - start for x in current_output_index]
+            if not current_output_index:
+                # No kept channel in the current group
+                raise Exception(" Donnot support removing the whole group filter except in the depth-wise conv temporarily")
+            assert len(current_output_index) == out_channel_group, \
+                'Output channel of each group should be the same after pruning'
+            current_output_index = torch.tensor(current_output_index).to(tmp_weight_data.device)  # pylint: disable=not-callable
+            new_start = groupid * new_in_per_group
+            new_end = (groupid + 1) * new_in_per_group
+            new_convtrans.weight.data[new_start:new_end] = torch.index_select(tmp_weight_data[new_start:new_end], 1, current_output_index)
+    else:
+        new_convtrans.weight.data.copy_(tmp_weight_data)
+    if convtrans.bias is not None:
+        if mask.output_mask is not None:
+            new_convtrans.bias.data[:] = torch.index_select(convtrans.bias.data, 0, out_channels_index)
+        else:
+            new_convtrans.bias.data.copy_(convtrans.bias.data)
+    return new_convtrans
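The transposed weight layout that motivates the separate replacement function is easy to confirm in standard PyTorch:

import torch

conv = torch.nn.Conv2d(in_channels=8, out_channels=16, kernel_size=3)
deconv = torch.nn.ConvTranspose2d(in_channels=8, out_channels=16, kernel_size=3)
print(conv.weight.shape)    # torch.Size([16, 8, 3, 3]) -- [C_out, C_in, k, k]
print(deconv.weight.shape)  # torch.Size([8, 16, 3, 3]) -- [C_in, C_out, k, k]
# Hence replace_convtranspose2d selects input channels along dim 0 and
# output channels along dim 1, the opposite of replace_conv2d.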
nni/compression/pytorch/speedup/infer_shape.py

@@ -13,6 +13,7 @@ _logger = logging.getLogger(__name__)
 conv_prune_dim = -1

 def set_conv_prune_dim(dim):
     """
     Parameters:
@@ -23,6 +24,7 @@ def set_conv_prune_dim(dim):
     global conv_prune_dim
     conv_prune_dim = dim

 class CoarseMask:
     """
     Coarse grained mask for a given tensor, here tensor could be weights,
@@ -228,6 +230,7 @@ Infer input and output shape of a module/function from its weight mask
 infer_from_mask = {
     'BatchNorm2d': lambda module_masks, mask: batchnorm2d_mask(module_masks, mask),
     'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask),
+    'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_mask(module_masks, mask),
     'Linear': lambda module_masks, mask, shape: linear_mask(module_masks, mask, shape)
 }
@@ -246,6 +249,7 @@ infer_from_inshape = {
     'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'aten::sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
+    'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_inshape(module_masks, mask),
     'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
     'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
     'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
@@ -277,6 +281,7 @@ Infer input and weight shape of a module/function from its output shape
 """
 infer_from_outshape = {
     'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask),
+    'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_outshape(module_masks, mask),
     'BatchNorm2d': lambda module_masks, mask: batchnorm2d_outshape(module_masks, mask),
     'MaxPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
@@ -306,6 +311,7 @@ infer_from_outshape = {
     'aten::dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask)
 }

 def dropout_inshape(module_masks, mask):
     if module_masks.input_mask is None:
         module_masks.set_input_mask(mask)
@@ -325,6 +331,7 @@ def dropout_inshape(module_masks, mask):
     module_masks.set_output_mask(mask)
     return module_masks.output_mask

 def dropout_outshape(module_masks, mask):
     if module_masks.output_mask is None:
         module_masks.set_output_mask(mask)
@@ -335,6 +342,7 @@ def dropout_outshape(module_masks, mask):
     return module_masks.output_mask

 def cat_inshape(module_masks, mask, cat_info, last_visited):
     """
     Inference the output mask of the cat operation from the
@@ -433,6 +441,7 @@ def add_inshape(module_masks, mask):
         raise Exception('Mask conflict happenes!')
     return None

 def add_outshape(module_masks, mask):
     """
     Inference the input mask of the add operation from the
@@ -445,9 +454,11 @@ def add_outshape(module_masks, mask):
         module_masks.set_input_mask(mask)
         return mask
     else:
         assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
     return mask

 def batchnorm2d_inshape(module_masks, mask):
     """
     We assume only the second dimension has coarse grained mask
@@ -477,6 +488,7 @@ def batchnorm2d_inshape(module_masks, mask):
     module_masks.set_param_masks('bias', weight_cmask)
     return mask

 def batchnorm2d_outshape(module_masks, mask):
     """
     We assume only the second dimension has coarse grained mask
@@ -577,6 +589,7 @@ def view_inshape(module_masks, mask, shape):
     module_masks.set_output_mask(output_cmask)
     return output_cmask

 def view_outshape(module_masks, mask, shape):
     """
     Parameters
@@ -614,12 +627,14 @@ def view_outshape(module_masks, mask, shape):
     return input_cmask

 def size_inshape(module_masks, mask):
     """
     No need to do anything for this ```size``` op
     """
     return None

 def mean_inshape(module_masks, mask, shape):
     """
     Similar to view operation, currently mask inference only supports
@@ -642,6 +657,7 @@ def mean_inshape(module_masks, mask, shape):
     module_masks.set_output_mask(output_cmask)
     return output_cmask

 def mean_outshape(module_masks, mask, shape):
     """
     Similar to view operation, currently mask inference only supports
@@ -662,6 +678,7 @@ def mean_outshape(module_masks, mask, shape):
     module_masks.set_input_mask(input_cmask)
     return input_cmask

 def maxpool2d_inshape(module_masks, mask):
     """
     Assume only the second dimension is masked
@@ -690,6 +707,7 @@ def maxpool2d_inshape(module_masks, mask):
     module_masks.set_output_mask(mask)
     return mask

 def maxpool2d_outshape(module_masks, mask):
     """
     Assume only the second dimension is masked
@@ -714,6 +732,7 @@ def maxpool2d_outshape(module_masks, mask):
     module_masks.set_output_mask(mask)
     return mask

 def relu_inshape(module_masks, mask):
     """
     Parameters
@@ -737,6 +756,7 @@ def relu_inshape(module_masks, mask):
     module_masks.set_output_mask(mask)
     return mask

 def relu_outshape(module_masks, mask):
     """
     Parameters
@@ -754,11 +774,13 @@ def relu_outshape(module_masks, mask):
     assert isinstance(mask, CoarseMask)
     if module_masks.output_mask is not None:
         # mask conflict should be solved before speedup
         assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
     module_masks.set_input_mask(mask)
     module_masks.set_output_mask(mask)
     return mask

 def batchnorm2d_mask(module_masks, mask):
     """
     Infer input and output shape from weight mask
@@ -792,6 +814,7 @@ def batchnorm2d_mask(module_masks, mask):
     module_masks.set_output_mask(output_cmask)
     return input_cmask, output_cmask

 def linear_mask(module_masks, mask, shape):
     """
     Infer input and output shape from weight mask with limitations:
@@ -825,6 +848,7 @@ def linear_mask(module_masks, mask, shape):
     module_masks.set_input_mask(input_cmask)
     return input_cmask, None

 def conv2d_mask(module_masks, mask):
     """
     Infer input and output shape from weight mask
@@ -863,8 +887,9 @@ def conv2d_mask(module_masks, mask):
         weight_mask = mask['weight']
         sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
-        index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=True)[0]
+        index = torch.nonzero(weight_mask.abs().sum(
+            sum_idx) != 0, as_tuple=True)[0]
         if len(index) == weight_mask.shape[dim]:  # full mask
             index = None
         if index is None:
@@ -882,7 +907,8 @@ def conv2d_mask(module_masks, mask):
             bias_cmask.add_index_mask(dim=0, index=bias_index)
         return index, weight_cmask, bias_cmask

     index, weight_cmask, bias_cmask = convert_to_coarse_mask(mask, dim=conv_prune_dim)
     if index is None:
         # TODO: fine grained mask speedup
@@ -910,7 +936,8 @@ def conv2d_mask(module_masks, mask):
         module_masks.set_input_mask(io_cmask)
     else:
         assert module_masks.input_mask == io_cmask
     return module_masks.input_mask, None

 def conv2d_inshape(module_masks, mask):
     """
@@ -972,7 +999,8 @@ def conv2d_outshape(module_masks, mask):
         # mask conflict should be solved by fix_mask_conflict before speedup
         # mask and module_masks.output_mask may have different number of dimensions
         # since they could be passed by linear or conv2d
         assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
     weight_cmask = CoarseMask(num_dim=4)
     weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
@@ -988,3 +1016,74 @@ def conv2d_outshape(module_masks, mask):
         module_masks.input_mask = mask
         return mask
     return None
+
+def convtranspose2d_mask(module_masks, mask):
+    # TODO support the Convtranspose2d Pruning for the L1FilterPruner
+    raise Exception("Current Filter pruner cannot prune the ConvTranspose2d, will support pruning ConvTranspose2d later")
+
+def convtranspose2d_inshape(module_masks, mask):
+    """
+    Shape change of input tensor does not affect the shape of its output tensor
+    Parameters
+    ----------
+    module_masks : ModuleMasks
+        The ModuleMasks instance of the conv2d
+    mask : CoarseMask
+        The mask of its input tensor
+    Returns
+    -------
+    CoarseMask
+        The mask of its output tensor
+    """
+    assert isinstance(mask, CoarseMask)
+    if module_masks.input_mask is None:
+        module_masks.set_input_mask(mask)
+    else:
+        # the same conv layer may be accessed more
+        # than once, such as a concat operation.
+        # mask conflict should be solved by fix_mask_conflict before speedup
+        assert module_masks.input_mask == mask
+    # shape changes pass through depths wise conv layers
+    m = module_masks.module
+    if m.in_channels == m.out_channels == m.groups:
+        module_masks.output_mask = mask
+        module_masks.input_mask = mask
+        return mask
+    return None
+
+def convtranspose2d_outshape(module_masks, mask):
+    assert isinstance(mask, CoarseMask)
+    assert mask.mask_index[1] is not None
+    assert mask.mask_index[0] is None
+    assert mask.mask_index[2] is None
+    assert mask.mask_index[3] is None
+    if module_masks.output_mask is None:
+        module_masks.output_mask = mask
+    else:
+        # mask conflict should be solved by fix_mask_conflict before speedup
+        # mask and module_masks.output_mask may have different number of dimensions
+        # since they could be passed by linear or conv2d
+        assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
+    weight_cmask = CoarseMask(num_dim=4)
+    # Note the memory layout of Convtranspose2d is C_in, C_out, k1, k2
+    weight_cmask.add_index_mask(dim=1, index=mask.mask_index[1])
+    bias_cmask = CoarseMask(num_dim=1)
+    bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
+    module_masks.set_param_masks('weight', weight_cmask)
+    module_masks.set_param_masks('bias', bias_cmask)
+    # shape changes pass through depths wise conv layers
+    m = module_masks.module
+    if m.in_channels == m.out_channels == m.groups:
+        module_masks.output_mask = mask
+        module_masks.input_mask = mask
+        return mask
+    return None
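The convert_to_coarse_mask step inside conv2d_mask, which both the Conv2d path and the new ConvTranspose2d handlers build on, reduces a fine-grained weight mask to kept-channel indices. A small sketch of that reduction (tensor shapes here are illustrative):

import torch

# A [C_out, C_in, k, k] weight mask in which filter 1 is fully zeroed.
weight_mask = torch.ones(4, 3, 3, 3)
weight_mask[1] = 0

# Pruning along dim 0 (filters): collapse the other dims, keep nonzeros.
sum_idx = (1, 2, 3)
index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=True)[0]
print(index)  # tensor([0, 2, 3]) -- filter 1 was pruned
# If every filter survives, conv2d_mask treats it as a full mask (index = None).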
nni/compression/pytorch/utils/mask_conflict.py

@@ -9,6 +9,7 @@ from .utils import get_module_by_name
 # logging.basicConfig(level = logging.DEBUG)
 _logger = logging.getLogger(__name__)

 def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
     """
     MaskConflict fix the mask conflict for the channel dependencies
@@ -50,6 +51,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
     masks = padding_cat_mask.fix_mask()
     return masks, fix_channel_mask.conv_prune_dim

 class MaskFix:
     def __init__(self, masks, model=None, dummy_input=None, traced=None):
         # check if the parameters are valid
@@ -74,6 +76,7 @@ class MaskFix:
         """
         torch.save(self.masks, path)

 class CatMaskPadding(MaskFix):
     def __init__(self, masks, model, dummy_input=None, traced=None):
         """
@@ -100,7 +103,8 @@ class CatMaskPadding(MaskFix):
         super(CatMaskPadding, self).__init__(masks, model, dummy_input, traced)

     def fix_mask(self):
         cat_padding_depen = CatPaddingDependency(self.model, self.dummy_input, self.traced)
         name_to_module = {}
         for name, module in self.model.named_modules():
             name_to_module[name] = module
@@ -131,11 +135,10 @@ class CatMaskPadding(MaskFix):
                     # module.bias may be None
                     b_shape = module.bias.data.size()
                     b_mask = torch.ones(b_shape).to(device)
                 self.masks[layer] = {'weight': w_mask, 'bias': b_mask}
         return self.masks

 class GroupMaskConflict(MaskFix):
     def __init__(self, masks, model=None, dummy_input=None, traced=None):
         """
@@ -154,8 +157,8 @@ class GroupMaskConflict(MaskFix):
             the traced model of the target model, is this parameter is not None,
             we donnot use the model and dummpy_input to get the trace graph.
         """
         super(GroupMaskConflict, self).__init__(masks, model, dummy_input, traced)

     def fix_mask(self):
         """
@@ -163,7 +166,8 @@ class GroupMaskConflict(MaskFix):
         has group dependencies. This function should be called before the
         mask inference of the 'speedup' module.
         """
         group_depen = GroupDependency(self.model, self.dummy_input, self.traced)
         depens = group_depen.dependency
         _logger.info(depens)
         for layername in depens:
@@ -174,8 +178,10 @@ class GroupMaskConflict(MaskFix):
             w_mask = self.masks[layername]['weight']
             shape = w_mask.size()
             count = np.prod(shape[1:])
-            all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist()
+            all_ones = (w_mask.flatten(1).sum(-1) ==
+                        count).nonzero().squeeze(1).tolist()
             all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist()
             if len(all_ones) + len(all_zeros) < w_mask.size(0):
                 # In fine-grained pruning, skip this layer
                 _logger.info('Layers %s using fine-grained pruning', layername)
@@ -190,7 +196,8 @@ class GroupMaskConflict(MaskFix):
             for i in range(group):
                 _start = step * i
                 _end = step * (i + 1)
                 _tmp_list = list(filter(lambda x: _start <= x and x < _end, all_zeros))
                 group_masked.append(_tmp_list)
             mini_masked = min([len(x) for x in group_masked])
             for gm in group_masked:
@@ -198,13 +205,13 @@ class GroupMaskConflict(MaskFix):
                     # To keep the output channel number still being divisible to
                     # groups, we set the masks of following filters to be zero.
                     pos = gm[i]
-                    self.masks[layername]['weight'][pos] = torch.ones(shape[1:])
-                    if hasattr(self.masks[layername], 'bias'):
+                    self.masks[layername]['weight'][pos] = torch.ones(
+                        shape[1:])
+                    if 'bias' in self.masks[layername] and self.masks[layername]['bias'] is not None:
                         self.masks[layername]['bias'][pos] = 1
         return self.masks

 class ChannelMaskConflict(MaskFix):
     def __init__(self, masks, model=None, dummy_input=None, traced=None):
         """
@@ -223,7 +230,8 @@ class ChannelMaskConflict(MaskFix):
         the traced graph of the target model, is this parameter is not None,
         we donnot use the model and dummpy_input to get the trace graph.
         """
         super(ChannelMaskConflict, self).__init__(masks, model, dummy_input, traced)
         self.conv_prune_dim = detect_mask_prune_dim(masks, model)
         _logger.info('detected conv prune dim: %s', self.conv_prune_dim)
@@ -235,9 +243,11 @@ class ChannelMaskConflict(MaskFix):
         are supported.
         """
         if self.conv_prune_dim == 0:
             channel_depen = ChannelDependency(self.model, self.dummy_input, self.traced)
         else:
             channel_depen = InputChannelDependency(self.model, self.dummy_input, self.traced)
         depen_sets = channel_depen.dependency_sets
         sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)
         for dset in depen_sets:
@@ -262,17 +272,29 @@ class ChannelMaskConflict(MaskFix):
                     channel_masks.append((mask.abs().sum(0) != 0).int())
                 elif type(m).__name__ == 'BatchNorm2d':
                     channel_masks.append(mask.int())
+                elif type(m).__name__ == 'ConvTranspose2d':
+                    # convtranspose have difference memory layout, so that we need create
+                    # a tmp_sum_idx for conv_transpose
+                    tmp_sum_idx = (0, 2, 3) if self.conv_prune_dim == 0 else (1, 2, 3)
+                    channel_mask = (mask.abs().sum(tmp_sum_idx) != 0).int()
+                    channel_masks.append(channel_mask)
+                    if (channel_mask.sum() * (mask.numel() / mask.shape[1 - self.conv_prune_dim])).item() != (mask > 0).sum().item():
+                        fine_grained = True
                 else:
                     raise RuntimeError(f'unsupported module type: {type(m).__name__}')
             else:
                 # no mask means not pruned, equivlent to full masks
                 channel_masks.append(None)
         if fine_grained:
             _logger.info('fine-grained mask detected, skip solving conflict for this set: %s', dset)
             continue
         if all(x is None for x in channel_masks):
             continue
         num_channels_list = [len(x) for x in channel_masks if x is not None]
         # number of channels in same set should be identical
         assert len(set(num_channels_list)) == 1
         num_channels = num_channels_list[0]
@@ -284,7 +306,8 @@ class ChannelMaskConflict(MaskFix):
         # merge masks with 'or'
         merged_channel_mask = channel_masks[0].clone()
         for i in range(1, len(channel_masks)):
-            merged_channel_mask = ((merged_channel_mask + channel_masks[i]) != 0).int()
+            merged_channel_mask = (
+                (merged_channel_mask + channel_masks[i]) != 0).int()
         merged_index = torch.nonzero(merged_channel_mask, as_tuple=True)[0]
@@ -305,16 +328,19 @@ class ChannelMaskConflict(MaskFix):
             elif type(m).__name__ == 'BatchNorm2d':
                 new_mask = merged_index.type_as(orig_mask)
             else:
                 raise RuntimeError(f'unsupported module type: {type(m).__name__}')
             self.masks[name]['weight'] = new_mask
             if 'bias' in self.masks[name] and self.masks[name]['bias'] is not None:
                 if type(m).__name__ == 'Conv2d':
                     assert self.conv_prune_dim == 0
                 self.masks[name]['bias'] = merged_channel_mask.type_as(self.masks[name]['bias'])
         return self.masks

 def detect_mask_prune_dim(masks, model):
     """
     Detect how the masks of convolutional layers are pruned.
@@ -358,7 +384,8 @@ def detect_mask_prune_dim(masks, model):
         _logger.warning('no multi-dimension masks found.')
         return 0
-    dim0_sparsity, dim1_sparsity = 1. - dim0_preserved / dim0_num, 1. - dim1_preserved / dim1_num
+    dim0_sparsity, dim1_sparsity = 1. - dim0_preserved / \
+        dim0_num, 1. - dim1_preserved / dim1_num
     _logger.info('dim0 sparsity: %f', dim0_sparsity)
     _logger.info('dim1 sparsity: %f', dim1_sparsity)
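In practice these fixers run through the fix_mask_conflict entry point before speedup. A typical invocation might look like this (the model and mask file are placeholders):

import torch
from torchvision.models import resnet18
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict

model = resnet18()
dummy_input = torch.randn(1, 3, 224, 224)
masks = torch.load('mask.pth')  # placeholder: produced by a pruner's export step
# Returns the conflict-free masks plus the detected conv pruning dimension.
fixed_masks, conv_prune_dim = fix_mask_conflict(masks, model, dummy_input)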
nni/compression/pytorch/utils/shape_dependency.py

@@ -4,13 +4,16 @@
 import csv
 import logging

 __all__ = ['ChannelDependency', 'GroupDependency', 'CatPaddingDependency', 'InputChannelDependency']

 CONV_TYPE = 'aten::_convolution'
 ADD_TYPES = ['aten::add', 'aten::add_']
 CAT_TYPE = 'aten::cat'
 logger = logging.getLogger('Shape_Dependency')
 RESHAPE_OPS = [CAT_TYPE, 'aten::view', 'aten::reshape', 'aten::flatten', 'aten::mean']

 class Dependency:
     def __init__(self, model=None, dummy_input=None, traced_model=None):
@@ -34,6 +37,7 @@ class Dependency:
     def export(self, filepath):
         raise NotImplementedError

 class ChannelDependency(Dependency):
     def __init__(self, model=None, dummy_input=None, traced_model=None):
         """
@@ -50,7 +54,8 @@ class ChannelDependency(Dependency):
         if we alreay has the traced graph of the target model, we donnot
         need to trace the model again.
         """
         super(ChannelDependency, self).__init__(model, dummy_input, traced_model)

     def _get_parent_layers(self, node):
         """
@@ -71,7 +76,7 @@ class ChannelDependency(Dependency):
         queue.append(node)
         while queue:
             curnode = queue.pop(0)
-            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear':
+            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
                 # find the first met conv
                 parent_layers.append(curnode.name)
                 continue
@@ -119,7 +124,6 @@ class ChannelDependency(Dependency):
         for _node in dependency_set:
             self.dependency[_node] = dependency_set

     def export(self, filepath):
         """
         export the channel dependencies as a csv file.
@@ -185,6 +189,7 @@ class ChannelDependency(Dependency):
             d_sets.append(tmp_set)
         return d_sets

 def reshape_break_channel_dependency(op_node):
     """
     The reshape operations such as (reshape, view, flatten) may break
@@ -213,6 +218,7 @@ def reshape_break_channel_dependency(op_node):
     out_channel = out_shape[1]
     return in_channel != out_channel

 class InputChannelDependency(ChannelDependency):
     """
     Some pruners may prune the input channel of the convolutional
@@ -242,7 +248,8 @@ class InputChannelDependency(ChannelDependency):
         if we alreay has the traced graph of the target model, we donnot
         need to trace the model again.
         """
         super(InputChannelDependency, self).__init__(model, dummy_input, traced_model)

     def _get_following_convs(self, tensor):
         queue = []
@@ -250,14 +257,14 @@ class InputChannelDependency(ChannelDependency):
         queue.extend(self.graph.input_to_node[tensor])
         while queue:
             curnode = queue.pop(0)
-            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear':
+            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
                 # find the first met conv
                 key_layers.append(curnode.name)
                 continue
             elif curnode.op_type in RESHAPE_OPS:
                 # check if the reshape operation will break the channel dependency
                 if reshape_break_channel_dependency(curnode):
                     # reshape operations also breaks the dependency relationship
                     continue
             successors = self.graph.find_successors(curnode.unique_name)
             successors = [self.graph.name_to_node[name] for name in successors]
@@ -290,7 +297,8 @@ class InputChannelDependency(ChannelDependency):
 class CatPaddingDependency(ChannelDependency):
     def __init__(self, model=None, dummy_input=None, traced_model=None):
         super(CatPaddingDependency, self).__init__(model, dummy_input, traced_model)

     def build_dependency(self):
         """
@@ -347,6 +355,7 @@ class CatPaddingDependency(ChannelDependency):
             row.extend(list(layers))
             csv_w.writerow(row)

 class GroupDependency(Dependency):
     def __init__(self, model=None, dummy_input=None, traced_model=None):
         """
@@ -388,7 +397,7 @@ class GroupDependency(Dependency):
         queue = predeessors
         while queue:
             curnode = queue.pop(0)
-            if curnode.op_type == 'Conv2d':
+            if curnode.op_type == 'Conv2d' or curnode.op_type == 'ConvTranspose2d':
                 # find the first met conv
                 parent_layers.append(curnode.name)
                 continue
@@ -412,7 +421,8 @@ class GroupDependency(Dependency):
         group : int
             the number of the groups of the target conv layer.
         """
         cpp_conv = list(filter(lambda x: x.kind() == CONV_TYPE, node_group.node_cpps))
         assert len(cpp_conv) == 1
         cpp_conv = cpp_conv[0]
         inputs = list(cpp_conv.inputs())
@@ -442,12 +452,14 @@ class GroupDependency(Dependency):
         filters should be divisible to.
         """
         for node in self.graph.nodes_py.nodes_op:
-            if node.op_type == 'Conv2d':
+            if node.op_type == 'Conv2d' or node.op_type == 'ConvTranspose2d':
                 group = self._get_conv_groups(node)
                 if node.name in self.dependency:
                     # the conv layer whose group is larger than 1 will require that
                     # it's number of output channel to be divisible by the number of group.
                     self.dependency[node.name] = max(self.dependency[node.name], group)
                 else:
                     self.dependency[node.name] = group
                 if group > 1:
@@ -456,7 +468,8 @@
...
@@ -456,7 +468,8 @@ class GroupDependency(Dependency):
parent_convs
=
self
.
_get_parent_convs
(
node
)
parent_convs
=
self
.
_get_parent_convs
(
node
)
for
parent
in
parent_convs
:
for
parent
in
parent_convs
:
if
parent
in
self
.
dependency
:
if
parent
in
self
.
dependency
:
self
.
dependency
[
parent
]
=
max
(
self
.
dependency
[
parent
],
group
)
self
.
dependency
[
parent
]
=
max
(
self
.
dependency
[
parent
],
group
)
else
:
else
:
self
.
dependency
[
parent
]
=
group
self
.
dependency
[
parent
]
=
group
return
self
.
dependency
return
self
.
dependency
...
@@ -484,6 +497,7 @@ class GroupDependency(Dependency):
...
@@ -484,6 +497,7 @@ class GroupDependency(Dependency):
for
name
in
self
.
dependency
:
for
name
in
self
.
dependency
:
group
=
self
.
dependency
[
name
]
group
=
self
.
dependency
[
name
]
csv_w
.
writerow
([
name
,
group
])
csv_w
.
writerow
([
name
,
group
])
@
property
@
property
def
dependency_sets
(
self
):
def
dependency_sets
(
self
):
return
self
.
dependency
return
self
.
dependency
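The exported CSV above has one `layer, group` row per convolution; a pruner consuming it must keep a channel count divisible by the group. A minimal sketch of such a consumer, assuming only the CSV layout shown above (`planned_remaining` is a hypothetical per-layer channel budget, not part of this file):

    import csv

    def round_to_group(csv_path, planned_remaining):
        """Round each layer's kept-channel count down to a multiple of its group."""
        with open(csv_path, newline='') as f:
            for layer, group in csv.reader(f):
                group = int(group)
                if layer in planned_remaining and group > 1:
                    kept = planned_remaining[layer]
                    # keep at least one full group of channels
                    planned_remaining[layer] = max(group, kept - kept % group)
        return planned_remaining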
nni/experiment/__init__.py  View file @ 3ec26b40
...
@@ -2,6 +2,6 @@
 # Licensed under the MIT license.

 from .config import *
-from .experiment import Experiment, RetiariiExperiment
+from .experiment import Experiment
 from .nni_client import *
nni/experiment/config/__init__.py  View file @ 3ec26b40

-from .base import ExperimentConfig, RetiariiExpConfig
-from .local import LocalExperimentConfig
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from .common import *
+from .local import *
nni/experiment/config/base.py  View file @ 3ec26b40

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+import copy
 import dataclasses
-import json
 from pathlib import Path
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Optional, Type, TypeVar
+
+from ruamel import yaml
+
+from . import util
+
+__all__ = ['ConfigBase', 'PathLike']
+
+T = TypeVar('T', bound='ConfigBase')
+
+PathLike = util.PathLike
+
+def _is_missing(obj: Any) -> bool:
+    return isinstance(obj, type(dataclasses.MISSING))

-@dataclasses.dataclass(init=False)
-class ExperimentConfig:
-    experiment_name: str
-    search_space: Any
-    max_execution_seconds: Optional[int] = None
-    max_trial_number: Optional[int] = None
-    trial_concurrency: int
-    trial_command: str
-    trial_code_directory: Union[Path, str]
-    trial_gpu_number: int = 0
-    extra_config: Optional[Dict[str, str]] = None
-    _training_service: str
-
-    # these values will be used to create template object,
-    # and the user should overwrite them later.
-    _placeholder = {
-        'experiment_name': '_unset_',
-        'search_space': '_unset_',
-        'trial_concurrency': -1,
-        'trial_command': '_unset_',
-        'trial_code_directory': '_unset_'
-    }
-
-    # simple validation functions
-    # complex validation logic with special error message should go to `validate()` method instead
-    _value_range = {
-        'max_execution_seconds': lambda x: x is None or x > 0,
-        'max_trial_number': lambda x: x is None or x > 0,
-        'trial_concurrency': lambda x: x > 0,
-        'trial_gpu_number': lambda x: x >= 0
-    }
-
-    def __init__(self, **kwargs):
-        for field in dataclasses.fields(self):
-            if field.name in kwargs:
-                setattr(self, field.name, kwargs[field.name])
-            elif field.default != dataclasses.MISSING:
-                setattr(self, field.name, field.default)
-            else:
-                setattr(self, field.name, type(self)._placeholder[field.name])
-
-    def validate(self) -> None:
-        # check existence
-        for key, placeholder_value in type(self)._placeholder.items():
-            if getattr(self, key) == placeholder_value:
-                raise ValueError(f'Field "{key}" is not set')
-
-        # TODO: check type
-
-        # check value
-        for key, condition in type(self)._value_range.items():
-            value = getattr(self, key)
-            if not condition(value):
-                raise ValueError(f'Field "{key}" ({repr(value)}) out of range')
-
-        # check special fields
-        if not Path(self.trial_code_directory).is_dir():
-            raise ValueError(f'Trial code directory "{self.trial_code_directory}" does not exist or is not directory')
-
-    def experiment_config_json(self) -> Dict[str, Any]:
-        # this only contains the common part for most (if not all) training services
-        # subclasses should override it to provide exclusive fields
-        return {
-            'authorName': '_',
-            'experimentName': self.experiment_name,
-            'trialConcurrency': self.trial_concurrency,
-            'maxExecDuration': self.max_execution_seconds or (999 * 24 * 3600),
-            'maxTrialNum': self.max_trial_number or 99999,
-            'searchSpace': json.dumps(self.search_space),
-            'trainingServicePlatform': self._training_service,
-            'tuner': {'builtinTunerName': '_user_created_'},
-            **(self.extra_config or {})
-        }
-
-    def cluster_metadata_json(self) -> Any:
-        # the cluster metadata format is a total mess
-        # leave it to each subclass before we refactor nni manager
-        raise NotImplementedError()
-
-    @staticmethod
-    def create_template(training_service: str) -> 'ExperimentConfig':
-        for cls in ExperimentConfig.__subclasses__():
-            for field in dataclasses.fields(cls):
-                if field.name == '_training_service' and field.default == training_service:
-                    return cls()
-        raise ValueError(f'Unrecognized training service {training_service}')
-
-
-class RetiariiExpConfig(ExperimentConfig):
-    @staticmethod
-    def create_template(training_service: str) -> 'ExperimentConfig':
-        for cls in ExperimentConfig.__subclasses__():
-            for field in dataclasses.fields(cls):
-                if field.name == '_training_service' and field.default == training_service:
-                    config_obj = cls()
-                    config_obj.search_space = {}
-                    config_obj.trial_command = 'python3 -m nni.retiarii.trial_entry'
-                    # FIXME: expose this field to users
-                    config_obj.trial_code_directory = '../..'
-                    return config_obj
+class ConfigBase:
+    """
+    Base class of config classes.
+    Subclass may override `_canonical_rules` and `_validation_rules`,
+    and `validate()` if the logic is complex.
+    """
+
+    # Rules to convert field value to canonical format.
+    # The key is field name.
+    # The value is callable `value -> canonical_value`.
+    # It is not type-hinted so dataclass won't treat it as field.
+    _canonical_rules = {}  # type: ignore
+
+    # Rules to validate field value.
+    # The key is field name.
+    # The value is callable `value -> valid` or `value -> (valid, error_message)`.
+    # The rule will be called with canonical format and is only called when `value` is not None.
+    # `error_message` is used when `valid` is False.
+    # It will be prepended with class name and field name in exception message.
+    _validation_rules = {}  # type: ignore
+
+    def __init__(self, *, _base_path: Optional[Path] = None, **kwargs):
+        """
+        Initialize a config object and set some fields.
+        Name of keyword arguments can either be snake_case or camelCase.
+        They will be converted to snake_case automatically.
+        If a field is missing and doesn't have a default value, it will be set to `dataclasses.MISSING`.
+        """
+        kwargs = {util.case_insensitive(key): value for key, value in kwargs.items()}
+        if _base_path is None:
+            _base_path = Path()
+        for field in dataclasses.fields(self):
+            value = kwargs.pop(util.case_insensitive(field.name), field.default)
+            if value is not None and not _is_missing(value):
+                # relative paths loaded from config file are not relative to pwd
+                if 'Path' in str(field.type):
+                    value = Path(value).expanduser()
+                    if not value.is_absolute():
+                        value = _base_path / value
+                # convert nested dict to config type
+                if isinstance(value, dict):
+                    cls = util.strip_optional(field.type)
+                    if isinstance(cls, type) and issubclass(cls, ConfigBase):
+                        value = cls(**value, _base_path=_base_path)
+            setattr(self, field.name, value)
+        if kwargs:
+            cls = type(self).__name__
+            fields = ', '.join(kwargs.keys())
+            raise ValueError(f'{cls}: Unrecognized fields {fields}')
+
+    @classmethod
+    def load(cls: Type[T], path: PathLike) -> T:
+        """
+        Load config from YAML (or JSON) file.
+        Keys in YAML file can either be camelCase or snake_case.
+        """
+        data = yaml.safe_load(open(path))
+        if not isinstance(data, dict):
+            raise ValueError(f'Content of config file {path} is not a dict/object')
+        return cls(**data, _base_path=Path(path).parent)
+
+    def json(self) -> Dict[str, Any]:
+        """
+        Convert config to JSON object.
+        The keys of returned object will be camelCase.
+        """
+        return dataclasses.asdict(
+            self.canonical(),
+            dict_factory=lambda items: dict((util.camel_case(k), v) for k, v in items if v is not None)
+        )
+
+    def canonical(self: T) -> T:
+        """
+        Returns a deep copy, where the fields supporting multiple formats are converted to the canonical format.
+        Noticeably, relative path may be converted to absolute path.
+        """
+        ret = copy.deepcopy(self)
+        for field in dataclasses.fields(ret):
+            key, value = field.name, getattr(ret, field.name)
+            rule = ret._canonical_rules.get(key)
+            if rule is not None:
+                setattr(ret, key, rule(value))
+            elif isinstance(value, ConfigBase):
+                setattr(ret, key, value.canonical())
+        # value will be copied twice, should not be a performance issue anyway
+        return ret
+
+    def validate(self) -> None:
+        """
+        Validate the config object and raise Exception if it's ill-formed.
+        """
+        class_name = type(self).__name__
+        config = self.canonical()
+
+        for field in dataclasses.fields(config):
+            key, value = field.name, getattr(config, field.name)
+
+            # check existence
+            if _is_missing(value):
+                raise ValueError(f'{class_name}: {key} is not set')
+
+            # check type (TODO)
+            type_name = str(field.type).replace('typing.', '')
+            optional = any([
+                type_name.startswith('Optional['),
+                type_name.startswith('Union[') and 'NoneType' in type_name,
+                type_name == 'Any'
+            ])
+            if value is None:
+                if optional:
+                    continue
+                else:
+                    raise ValueError(f'{class_name}: {key} cannot be None')
+
+            # check value
+            rule = config._validation_rules.get(key)
+            if rule is not None:
+                try:
+                    result = rule(value)
+                except Exception:
+                    raise ValueError(f'{class_name}: {key} has bad value {repr(value)}')
+                if isinstance(result, bool):
+                    if not result:
+                        raise ValueError(f'{class_name}: {key} ({repr(value)}) is out of range')
+                else:
+                    if not result[0]:
+                        raise ValueError(f'{class_name}: {key} {result[1]}')
+
+            # check nested config
+            if isinstance(value, ConfigBase):
+                value.validate()
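ConfigBase drives parsing, canonicalization and checking entirely from the class-level rule tables, so a new config type only declares fields and rules. A minimal sketch of such a subclass (the field names here are invented for illustration):

    import dataclasses
    from nni.experiment.config.base import ConfigBase

    @dataclasses.dataclass(init=False)
    class WorkerConfig(ConfigBase):
        worker_count: int = 1
        label: str = 'default'

        _canonical_rules = {
            'label': lambda v: v.lower(),
        }
        _validation_rules = {
            'worker_count': lambda v: (v > 0, 'must be positive'),
        }

    config = WorkerConfig(workerCount=4)  # camelCase keys are normalized to snake_case
    config.validate()                     # raises ValueError if a rule fails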
nni/experiment/config/common.py  0 → 100644  View file @ 3ec26b40

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

from .base import ConfigBase, PathLike
from . import util

__all__ = ['ExperimentConfig', 'AlgorithmConfig', 'CustomAlgorithmConfig', 'TrainingServiceConfig']


@dataclass(init=False)
class _AlgorithmConfig(ConfigBase):
    name: Optional[str] = None
    class_name: Optional[str] = None
    code_directory: Optional[PathLike] = None
    class_args: Optional[Dict[str, Any]] = None

    def validate(self):
        super().validate()
        _validate_algo(self)


@dataclass(init=False)
class AlgorithmConfig(_AlgorithmConfig):
    name: str
    class_args: Optional[Dict[str, Any]] = None


@dataclass(init=False)
class CustomAlgorithmConfig(_AlgorithmConfig):
    class_name: str
    class_directory: Optional[PathLike] = None
    class_args: Optional[Dict[str, Any]] = None


class TrainingServiceConfig(ConfigBase):
    platform: str


@dataclass(init=False)
class ExperimentConfig(ConfigBase):
    experiment_name: Optional[str] = None
    search_space_file: Optional[PathLike] = None
    search_space: Any = None
    trial_command: str
    trial_code_directory: PathLike = '.'
    trial_concurrency: int
    trial_gpu_number: int = 0
    max_experiment_duration: Optional[str] = None
    max_trial_number: Optional[int] = None
    nni_manager_ip: Optional[str] = None
    use_annotation: bool = False
    debug: bool = False
    log_level: Optional[str] = None
    experiment_working_directory: Optional[PathLike] = None
    tuner_gpu_indices: Optional[Union[List[int], str]] = None
    tuner: Optional[_AlgorithmConfig] = None
    accessor: Optional[_AlgorithmConfig] = None
    advisor: Optional[_AlgorithmConfig] = None
    training_service: TrainingServiceConfig

    def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        if training_service_platform is not None:
            assert 'training_service' not in kwargs
            self.training_service = util.training_service_config_factory(training_service_platform)

    def validate(self, initialized_tuner: bool = False) -> None:
        super().validate()
        if initialized_tuner:
            _validate_for_exp(self)
        else:
            _validate_for_nnictl(self)

    ## End of public API ##

    @property
    def _canonical_rules(self):
        return _canonical_rules

    @property
    def _validation_rules(self):
        return _validation_rules


_canonical_rules = {
    'search_space_file': util.canonical_path,
    'trial_code_directory': util.canonical_path,
    'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
    'experiment_working_directory': util.canonical_path,
    'tuner_gpu_indices': lambda value: [int(idx) for idx in value.split(',')] if isinstance(value, str) else value
}

_validation_rules = {
    'search_space_file': lambda value: (Path(value).is_file(), f'"{value}" does not exist or is not regular file'),
    'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not directory'),
    'trial_concurrency': lambda value: value > 0,
    'trial_gpu_number': lambda value: value >= 0,
    'max_experiment_duration': lambda value: util.parse_time(value) > 0,
    'max_trial_number': lambda value: value > 0,
    'log_level': lambda value: value in ["trace", "debug", "info", "warning", "error", "fatal"],
    'tuner_gpu_indices': lambda value: all(i >= 0 for i in value) and len(value) == len(set(value)),
    'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}


def _validate_for_exp(config: ExperimentConfig) -> None:
    # validate experiment for nni.Experiment, where tuner is already initialized outside
    if config.use_annotation:
        raise ValueError('ExperimentConfig: annotation is not supported in this mode')
    if util.count(config.search_space, config.search_space_file) != 1:
        raise ValueError('ExperimentConfig: exactly one of search_space and search_space_file must be set')
    if util.count(config.tuner, config.accessor, config.advisor) != 0:
        raise ValueError('ExperimentConfig: tuner, accessor, and advisor must not be set in this mode')
    if config.tuner_gpu_indices is not None:
        raise ValueError('ExperimentConfig: tuner_gpu_indices is not supported in this mode')


def _validate_for_nnictl(config: ExperimentConfig) -> None:
    # validate experiment for normal launching approach
    if config.use_annotation:
        if util.count(config.search_space, config.search_space_file) != 0:
            raise ValueError('ExperimentConfig: search_space and search_space_file must not be set with annotation')
    else:
        if util.count(config.search_space, config.search_space_file) != 1:
            raise ValueError('ExperimentConfig: exactly one of search_space and search_space_file must be set')
    if util.count(config.tuner, config.advisor) != 1:
        raise ValueError('ExperimentConfig: exactly one of tuner and advisor must be set')


def _validate_algo(algo: AlgorithmConfig) -> None:
    if algo.name is None:
        if algo.class_name is None:
            raise ValueError('Missing algorithm name')
        if algo.code_directory is not None and not Path(algo.code_directory).is_dir():
            raise ValueError(f'code_directory "{algo.code_directory}" does not exist or is not directory')
    else:
        if algo.class_name is not None or algo.code_directory is not None:
            raise ValueError('When name is set for registered algorithm, class_name and code_directory cannot be used')
    # TODO: verify algorithm installation and class args
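A short usage sketch of the schema above, assuming the LocalConfig defined in local.py later in this commit is the registered 'local' platform:

    from nni.experiment.config import ExperimentConfig

    config = ExperimentConfig(training_service_platform='local')
    config.search_space = {'lr': {'_type': 'loguniform', '_value': [1e-5, 1e-1]}}
    config.trial_command = 'python trial.py'
    config.trial_concurrency = 2
    config.max_experiment_duration = '1h'
    config.training_service.use_active_gpu = False  # no default, so it must be set

    config.validate(initialized_tuner=True)         # tuner supplied in code, not in config
    print(config.json()['maxExperimentDuration'])   # canonicalized to '3600s'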
nni/experiment/config/convert.py  0 → 100644  View file @ 3ec26b40

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Dict, List

from .common import ExperimentConfig
from . import util

_logger = logging.getLogger(__name__)


def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str, Any]:
    config.validate(skip_nnictl)
    data = config.json()

    ts = data.pop('trainingService')
    if ts['platform'] == 'openpai':
        ts['platform'] = 'pai'

    data['authorName'] = 'N/A'
    data['experimentName'] = data.get('experimentName', 'N/A')
    data['maxExecDuration'] = data.pop('maxExperimentDuration', '999d')
    if data['debug']:
        data['versionCheck'] = False
    data['maxTrialNum'] = data.pop('maxTrialNumber', 99999)
    data['trainingServicePlatform'] = ts['platform']
    ss = data.pop('searchSpace', None)
    ss_file = data.pop('searchSpaceFile', None)
    if ss is not None:
        ss_file = NamedTemporaryFile('w', delete=False)
        json.dump(ss, ss_file, indent=4)
        data['searchSpacePath'] = ss_file.name
    elif ss_file is not None:
        data['searchSpacePath'] = ss_file
    if 'experimentWorkingDirectory' in data:
        data['logDir'] = data.pop('experimentWorkingDirectory')

    for algo_type in ['tuner', 'assessor', 'advisor']:
        algo = data.get(algo_type)
        if algo is None:
            continue
        if algo['name'] is not None:  # builtin
            algo['builtin' + algo_type.title() + 'Name'] = algo.pop('name')
            algo.pop('className', None)
            algo.pop('codeDirectory', None)
        else:
            algo.pop('name', None)
            class_name_parts = algo.pop('className').split('.')
            algo['codeDir'] = algo.pop('codeDirectory', '') + '/'.join(class_name_parts[:-2])
            algo['classFileName'] = class_name_parts[-2] + '.py'
            algo['className'] = class_name_parts[-1]

    tuner_gpu_indices = _convert_gpu_indices(data.pop('tunerGpuIndices', None))
    if tuner_gpu_indices is not None:
        data['tuner']['gpuIndicies'] = tuner_gpu_indices

    data['trial'] = {
        'command': data.pop('trialCommand'),
        'codeDir': data.pop('trialCodeDirectory'),
        'gpuNum': data.pop('trialGpuNumber', '')
    }

    if ts['platform'] == 'local':
        data['localConfig'] = {
            'useActiveGpu': ts['useActiveGpu'],
            'maxTrialNumPerGpu': ts['maxTrialNumberPerGpu']
        }
        if ts.get('gpuIndices') is not None:
            data['localConfig']['gpuIndices'] = ','.join(str(idx) for idx in ts['gpuIndices'])

    elif ts['platform'] == 'remote':
        data['remoteConfig'] = {'reuse': ts['reuseMode']}
        data['machineList'] = []
        for machine in ts['machineList']:
            machine = {
                'ip': machine['host'],
                'username': machine['user'],
                'passwd': machine['password'],
                'sshKeyPath': machine['sshKeyFile'],
                'passphrase': machine['sshPassphrase'],
                'gpuIndices': _convert_gpu_indices(machine['gpuIndices']),
                'maxTrialNumPerGpu': machine['maxTrialNumPerGpu'],
                'useActiveGpu': machine['useActiveGpu'],
                'preCommand': machine['trialPrepareCommand']
            }

    elif ts['platform'] == 'pai':
        data['trial']['cpuNum'] = ts['trialCpuNumber']
        data['trial']['memoryMB'] = util.parse_size(ts['trialMemorySize'])
        data['trial']['image'] = ts['docker_image']
        data['paiConfig'] = {
            'userName': ts['username'],
            'token': ts['token'],
            'host': 'https://' + ts['host'],
            'reuse': ts['reuseMode']
        }

    return data


def _convert_gpu_indices(indices):
    return ','.join(str(idx) for idx in indices) if indices is not None else None


def to_cluster_metadata(config: ExperimentConfig) -> List[Dict[str, Any]]:
    experiment_config = to_v1_yaml(config, skip_nnictl=True)
    ret = []
    if config.training_service.platform == 'local':
        request_data = dict()
        request_data['local_config'] = experiment_config['localConfig']
        if request_data['local_config']:
            if request_data['local_config'].get('gpuIndices') and isinstance(request_data['local_config'].get('gpuIndices'), int):
                request_data['local_config']['gpuIndices'] = str(request_data['local_config'].get('gpuIndices'))
            if request_data['local_config'].get('maxTrialNumOnEachGpu'):
                request_data['local_config']['maxTrialNumOnEachGpu'] = request_data['local_config'].get('maxTrialNumOnEachGpu')
            if request_data['local_config'].get('useActiveGpu'):
                request_data['local_config']['useActiveGpu'] = request_data['local_config'].get('useActiveGpu')
        ret.append(request_data)
    elif config.training_service.platform == 'remote':
        request_data = dict()
        if experiment_config.get('remoteConfig'):
            request_data['remote_config'] = experiment_config['remoteConfig']
        else:
            request_data['remote_config'] = {'reuse': False}
        request_data['machine_list'] = experiment_config['machineList']
        if request_data['machine_list']:
            for i in range(len(request_data['machine_list'])):
                if isinstance(request_data['machine_list'][i].get('gpuIndices'), int):
                    request_data['machine_list'][i]['gpuIndices'] = str(request_data['machine_list'][i].get('gpuIndices'))
        ret.append(request_data)
    elif config.training_service.platform == 'openpai':
        pai_config_data = dict()
        pai_config_data['pai_config'] = experiment_config['paiConfig']
        ret.append(pai_config_data)
    else:
        raise RuntimeError('Unsupported training service ' + config.training_service.platform)

    if experiment_config.get('nniManagerIp') is not None:
        ret.append({'nni_manager_ip': {'nniManagerIp': experiment_config['nniManagerIp']}})
    ret.append({'trial_config': experiment_config['trial']})
    return ret


def to_rest_json(config: ExperimentConfig) -> Dict[str, Any]:
    experiment_config = to_v1_yaml(config, skip_nnictl=True)
    request_data = dict()
    request_data['authorName'] = experiment_config['authorName']
    request_data['experimentName'] = experiment_config['experimentName']
    request_data['trialConcurrency'] = experiment_config['trialConcurrency']
    request_data['maxExecDuration'] = util.parse_time(experiment_config['maxExecDuration'])
    request_data['maxTrialNum'] = experiment_config['maxTrialNum']
    if config.search_space is not None:
        request_data['searchSpace'] = json.dumps(config.search_space)
    else:
        request_data['searchSpace'] = Path(config.search_space_file).read_text()
    request_data['trainingServicePlatform'] = experiment_config.get('trainingServicePlatform')

    if experiment_config.get('advisor'):
        request_data['advisor'] = experiment_config['advisor']
        if request_data['advisor'].get('gpuNum'):
            _logger.warning('gpuNum is deprecated, please use gpuIndices instead.')
        if request_data['advisor'].get('gpuIndices') and isinstance(request_data['advisor'].get('gpuIndices'), int):
            request_data['advisor']['gpuIndices'] = str(request_data['advisor'].get('gpuIndices'))
    elif experiment_config.get('tuner'):
        request_data['tuner'] = experiment_config['tuner']
        if request_data['tuner'].get('gpuNum'):
            _logger.warning('gpuNum is deprecated, please use gpuIndices instead.')
        if request_data['tuner'].get('gpuIndices') and isinstance(request_data['tuner'].get('gpuIndices'), int):
            request_data['tuner']['gpuIndices'] = str(request_data['tuner'].get('gpuIndices'))
        if 'assessor' in experiment_config:
            request_data['assessor'] = experiment_config['assessor']
            if request_data['assessor'].get('gpuNum'):
                _logger.warning('gpuNum is deprecated, please remove it from your config file.')
    else:
        request_data['tuner'] = {'builtinTunerName': '_user_created_'}

    # debug mode should disable version check
    if experiment_config.get('debug') is not None:
        request_data['versionCheck'] = not experiment_config.get('debug')
    # validate version check
    if experiment_config.get('versionCheck') is not None:
        request_data['versionCheck'] = experiment_config.get('versionCheck')
    if experiment_config.get('logCollection'):
        request_data['logCollection'] = experiment_config.get('logCollection')

    request_data['clusterMetaData'] = []
    if experiment_config['trainingServicePlatform'] == 'local':
        request_data['clusterMetaData'].append(
            {'key': 'codeDir', 'value': experiment_config['trial']['codeDir']})
        request_data['clusterMetaData'].append(
            {'key': 'command', 'value': experiment_config['trial']['command']})
    elif experiment_config['trainingServicePlatform'] == 'remote':
        request_data['clusterMetaData'].append(
            {'key': 'machine_list', 'value': experiment_config['machineList']})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
        if not experiment_config.get('remoteConfig'):
            # set default value of reuse in remoteConfig to False
            experiment_config['remoteConfig'] = {'reuse': False}
        request_data['clusterMetaData'].append(
            {'key': 'remote_config', 'value': experiment_config['remoteConfig']})
    elif experiment_config['trainingServicePlatform'] == 'pai':
        request_data['clusterMetaData'].append(
            {'key': 'pai_config', 'value': experiment_config['paiConfig']})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
    elif experiment_config['trainingServicePlatform'] == 'kubeflow':
        request_data['clusterMetaData'].append(
            {'key': 'kubeflow_config', 'value': experiment_config['kubeflowConfig']})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
    elif experiment_config['trainingServicePlatform'] == 'frameworkcontroller':
        request_data['clusterMetaData'].append(
            {'key': 'frameworkcontroller_config', 'value': experiment_config['frameworkcontrollerConfig']})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
    elif experiment_config['trainingServicePlatform'] == 'aml':
        request_data['clusterMetaData'].append(
            {'key': 'aml_config', 'value': experiment_config['amlConfig']})
        request_data['clusterMetaData'].append(
            {'key': 'trial_config', 'value': experiment_config['trial']})
    return request_data
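A quick sketch of what the converters produce for a local experiment, reusing the config object from the common.py example above (illustrative values only):

    from nni.experiment.config import convert

    v1 = convert.to_v1_yaml(config, skip_nnictl=True)
    # legacy names come back: maxExecDuration, maxTrialNum, trial.codeDir, localConfig, ...
    assert v1['trainingServicePlatform'] == 'local'
    assert v1['maxExecDuration'] == '3600s'

    metadata = convert.to_cluster_metadata(config)
    # [{'local_config': ...}, {'trial_config': {'command': ..., 'codeDir': ...}}]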
nni/experiment/config/local.py  View file @ 3ec26b40
...
@@ -2,39 +2,25 @@
 # Licensed under the MIT license.

 from dataclasses import dataclass
-from pathlib import Path
-from typing import Any, Dict
+from typing import List, Optional, Union

-from .base import ExperimentConfig
+from .common import TrainingServiceConfig

+__all__ = ['LocalConfig']


 @dataclass(init=False)
-class LocalExperimentConfig(ExperimentConfig):
-    use_active_gpu: bool = False
-
-    _training_service: str = 'local'
-
-    def experiment_config_json(self) -> Dict[str, Any]:
-        ret = super().experiment_config_json()
-        ret['clusterMetaData'] = [
-            {
-                'key': 'codeDir',
-                'value': str(Path(self.trial_code_directory).resolve())
-            },
-            {
-                'key': 'command',
-                'value': self.trial_command
-            }
-        ]
-        #ret['local_config'] = {
-        #    'useActiveGpu': self.use_active_gpu
-        #}
-        return ret
-
-    def cluster_metadata_json(self) -> Any:
-        return {
-            'trial_config': {
-                'command': self.trial_command,
-                'codeDir': str(Path(self.trial_code_directory).resolve())
-            }
-        }
+class LocalConfig(TrainingServiceConfig):
+    platform: str = 'local'
+    use_active_gpu: bool
+    max_trial_number_per_gpu: int = 1
+    gpu_indices: Optional[Union[List[int], str]] = None
+
+    _canonical_rules = {
+        'gpu_indices': lambda value: [int(idx) for idx in value.split(',')] if isinstance(value, str) else value
+    }
+
+    _validation_rules = {
+        'platform': lambda value: (value == 'local', 'cannot be modified'),
+        'max_trial_number_per_gpu': lambda value: value > 0,
+        'gpu_indices': lambda value: all(idx >= 0 for idx in value) and len(value) == len(set(value))
+    }
nni/experiment/config/util.py  0 → 100644  View file @ 3ec26b40

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Miscellaneous utility functions.
"""

import math
import os.path
from pathlib import Path
from typing import Optional, Union

PathLike = Union[Path, str]

def case_insensitive(key: str) -> str:
    return key.lower().replace('_', '')

def camel_case(key: str) -> str:
    words = key.split('_')
    return words[0] + ''.join(word.title() for word in words[1:])

def canonical_path(path: Optional[PathLike]) -> Optional[str]:
    # Path.resolve() does not work on Windows when file not exist, so use os.path instead
    return os.path.abspath(os.path.expanduser(path)) if path is not None else None

def count(*values) -> int:
    return sum(value is not None and value is not False for value in values)

def training_service_config_factory(platform: str):  # -> TrainingServiceConfig
    from .common import TrainingServiceConfig
    for cls in TrainingServiceConfig.__subclasses__():
        if cls.platform == platform:
            return cls()
    raise ValueError(f'Unrecognized platform {platform}')

def strip_optional(type_hint):
    return type_hint.__args__[0] if str(type_hint).startswith('typing.Optional[') else type_hint

def parse_time(time: str, target_unit: str = 's') -> int:
    return _parse_unit(time.lower(), target_unit, _time_units)

def parse_size(size: str, target_unit: str = 'mb') -> int:
    return _parse_unit(size.lower(), target_unit, _size_units)

_time_units = {'d': 24 * 3600, 'h': 3600, 'm': 60, 's': 1}
_size_units = {'gb': 1024 * 1024 * 1024, 'mb': 1024 * 1024, 'kb': 1024}

def _parse_unit(string, target_unit, all_units):
    for unit, factor in all_units.items():
        if string.endswith(unit):
            number = string[:-len(unit)]
            value = float(number) * factor
            return math.ceil(value / all_units[target_unit])
    raise ValueError(f'Unsupported unit in "{string}"')
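The unit parsers above round up with math.ceil when converting between units; a few illustrative calls:

    from nni.experiment.config.util import parse_time, parse_size, camel_case

    print(parse_time('1h'))         # 3600 (seconds, the default target unit)
    print(parse_time('90m', 'h'))   # 2 -- partial hours are rounded up
    print(parse_size('1.5gb'))      # 1536 (megabytes, the default target unit)
    print(camel_case('max_trial_number'))  # 'maxTrialNumber'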
nni/experiment/experiment.py  View file @ 3ec26b40

+import atexit
 import logging
+import socket
 from subprocess import Popen
+import time
 from threading import Thread
-from typing import Optional, overload, List, Union, Callable
-import time
+from typing import Optional, overload

+import colorama
+import psutil

+import nni.runtime.log
 from nni.runtime.msg_dispatcher import MsgDispatcher
 from nni.tuner import Tuner
-from nni.retiarii.integration import RetiariiAdvisor
-from nni.retiarii.converter.graph_gen import convert_to_graph
 from .config import ExperimentConfig
 from . import launcher
 from .pipe import Pipe
 from . import rest

-_logger = logging.getLogger(__name__)
+nni.runtime.log.init_logger_experiment()
+_logger = logging.getLogger('nni.experiment')


 class Experiment:
     """
-    Controls an NNI experiment.
+    Create and stop an NNI experiment.
+
+    You may either create a new NNI experiment with constructor and `Experiment.start()`,
+    # TODO: or control an existing experiment with `Experiment.connect()`.

     Attributes
     ----------
...
@@ -42,7 +44,7 @@ class Experiment:
         Parameters
         ----------
         tuner
+            # TODO: accessor / advisor
             A tuner instance.
         config
             Experiment configuration.
         """
...
@@ -67,24 +69,24 @@ class Experiment:
             A tuner instance.
         training_service
             Name of training service.
-            Supported value: "local", "remote", "openpai"/"pai".
+            Supported value: "local", "remote", "openpai".
         """
     ...

     def __init__(self, tuner: Tuner, config=None, training_service=None):
         self.config: ExperimentConfig
         self.port: Optional[int] = None
-        self._dispatcher = MsgDispatcher(tuner, None)
+        self.tuner: Tuner = tuner
         self._proc: Optional[Popen] = None
         self._pipe: Optional[Pipe] = None
+        self._dispatcher: Optional[MsgDispatcher] = None
+        self._dispatcher_thread: Optional[Thread] = None

         if isinstance(config, str):
             config, training_service = None, config
-        if training_service == 'openpai':
-            training_service = 'pai'

         if config is None:
-            self.config = ExperimentConfig.create_template(training_service)
+            self.config = ExperimentConfig(training_service)
         else:
             self.config = config
...
@@ -103,6 +105,8 @@ class Experiment:
         debug
             Whether to start in debug mode.
         """
+        atexit.register(self.stop)
+
         if debug:
             logging.getLogger('nni').setLevel(logging.DEBUG)
...
@@ -112,9 +116,20 @@ class Experiment:
         self.port = port  # port will be None if start up failed

-        # dispatcher must be created after pipe initialized
+        # dispatcher must be launched after pipe initialized
         # the logic to launch dispatcher in background should be refactored into dispatcher api
-        Thread(target=self._dispatcher.run).start()
+        self._dispatcher = MsgDispatcher(self.tuner, None)
+        self._dispatcher_thread = Thread(target=self._dispatcher.run)
+        self._dispatcher_thread.start()
+
+        ips = [self.config.nni_manager_ip]
+        for interfaces in psutil.net_if_addrs().values():
+            for interface in interfaces:
+                if interface.family == socket.AF_INET:
+                    ips.append(interface.address)
+        ips = [f'http://{ip}:{port}' for ip in ips if ip]
+        msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips)
+        _logger.info(msg)

         # TODO: register experiment management metadata
...
@@ -123,27 +138,41 @@ class Experiment:
         """
         Stop background experiment.
         """
-        self._proc.kill()
-        self._pipe.close()
+        _logger.info('Stopping experiment...')
+        atexit.unregister(self.stop)
+
+        if self._proc is not None:
+            self._proc.kill()
+        if self._pipe is not None:
+            self._pipe.close()
+        if self._dispatcher_thread is not None:
+            self._dispatcher.stopping = True
+            self._dispatcher_thread.join(timeout=1)
+
         self.port = None
         self._proc = None
         self._pipe = None
+        self._dispatcher = None
+        self._dispatcher_thread = None

-    def run(self, port: int = 8080, debug: bool = False) -> str:
+    def run(self, port: int = 8080, debug: bool = False) -> bool:
         """
         Run the experiment.
         This function will block until experiment finish or error.
+        Return `True` when experiment done; or return `False` when experiment failed.
         """
         self.start(port, debug)
         try:
             while True:
                 time.sleep(10)
                 status = self.get_status()
-                if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
-                    return status
+                if status == 'STOPPED':
+                    return True
+                if status == 'ERROR':
+                    return False
         finally:
             self.stop()
...
@@ -153,97 +182,3 @@ class Experiment:
             raise RuntimeError('Experiment is not running')
         resp = rest.get(self.port, '/check-status')
         return resp['status']
-
-
-class RetiariiExperiment(Experiment):
-    def __init__(self, base_model: 'nn.Module', trainer: 'BaseTrainer',
-                 applied_mutators: List['Mutator'], strategy: 'BaseStrategy',
-                 tca: 'TraceClassArguments' = None):
-        self.config: ExperimentConfig = None
-        self.port: Optional[int] = None
-
-        self.base_model = base_model
-        self.trainer = trainer
-        self.applied_mutators = applied_mutators
-        self.strategy = strategy
-        self.recorded_module_args = tca.recorded_arguments  # FIXME: remove this argument
-
-        self._dispatcher = RetiariiAdvisor()
-        self._proc: Optional[Popen] = None
-        self._pipe: Optional[Pipe] = None
-
-    def _start_strategy(self):
-        import torch
-        script_module = torch.jit.script(self.base_model)
-        base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
-        assert id(self.trainer) in self.recorded_module_args
-        trainer_config = self.recorded_module_args[id(self.trainer)]
-        _logger.info('Starting strategy...')
-        Thread(target=self.strategy.run, args=(base_model, self.applied_mutators, trainer_config)).start()
-        _logger.info('Strategy started!')
-
-    def start(self, config: ExperimentConfig, port: int = 8080, debug: bool = False) -> None:
-        """
-        Start the experiment in background.
-        This method will raise exception on failure.
-        If it returns, the experiment should have been successfully started.
-        Parameters
-        ----------
-        port
-            The port of web UI.
-        debug
-            Whether to start in debug mode.
-        """
-        if debug:
-            logging.getLogger('nni').setLevel(logging.DEBUG)
-
-        self._proc, self._pipe = launcher.start_experiment(config, port, debug)
-        assert self._proc is not None
-        assert self._pipe is not None
-        self.port = port  # port will be None if start up failed
-
-        # dispatcher must be created after pipe initialized
-        # the logic to launch dispatcher in background should be refactored into dispatcher api
-        Thread(target=self._dispatcher.run).start()
-        self._start_strategy()
-        # TODO: register experiment management metadata
-
-    def stop(self) -> None:
-        """
-        Stop background experiment.
-        """
-        self._proc.kill()
-        self._pipe.close()
-        self.port = None
-        self._proc = None
-        self._pipe = None
-
-    def run(self, config: ExperimentConfig, port: int = 8080, debug: bool = False) -> str:
-        """
-        Run the experiment.
-        This function will block until experiment finish or error.
-        """
-        self.start(config, port, debug)
-        try:
-            while True:
-                time.sleep(10)
-                status = self.get_status()
-                if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
-                    return status
-        finally:
-            self.stop()
-
-    def get_status(self) -> str:
-        if self.port is None:
-            raise RuntimeError('Experiment is not running')
-        resp = rest.get(self.port, '/check-status')
-        return resp['status']
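End to end, the reworked class is used like this; a minimal sketch (the TPE tuner import path is an assumption about this NNI version, and trial.py is a placeholder script):

    from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner  # assumed import path
    from nni.experiment import Experiment

    tuner = HyperoptTuner('tpe')
    experiment = Experiment(tuner, training_service='local')
    experiment.config.search_space = {'lr': {'_type': 'loguniform', '_value': [1e-5, 1e-1]}}
    experiment.config.trial_command = 'python trial.py'  # placeholder trial script
    experiment.config.trial_concurrency = 2
    experiment.config.training_service.use_active_gpu = False

    succeeded = experiment.run(port=8080)  # blocks; True on STOPPED, False on ERROR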
nni/experiment/launcher.py  View file @ 3ec26b40

 import contextlib
+import logging
 from pathlib import Path
 import socket
 from subprocess import Popen
...
@@ -6,40 +7,46 @@ import sys
 import time
 from typing import Optional, Tuple

+import colorama

 import nni.runtime.protocol
 import nni_node

 from .config import ExperimentConfig
+from .config import convert
 from . import management
 from .pipe import Pipe
 from . import rest

+_logger = logging.getLogger('nni.experiment')


 def start_experiment(config: ExperimentConfig, port: int, debug: bool) -> Tuple[Popen, Pipe]:
     pipe = None
     proc = None

-    config.validate()
+    config.validate(initialized_tuner=True)
     _ensure_port_idle(port)
-    if config._training_service == 'pai':
+    if config.training_service.platform == 'openpai':
         _ensure_port_idle(port + 1, 'OpenPAI requires an additional port')

     exp_id = management.generate_experiment_id()

     try:
-        print(f'Creating experiment {exp_id}...')
+        _logger.info(f'Creating experiment {colorama.Fore.CYAN}{exp_id}')
         pipe = Pipe(exp_id)
         proc = _start_rest_server(config, port, debug, exp_id, pipe.path)
+        _logger.info('Connecting IPC pipe...')
         pipe_file = pipe.connect()
         nni.runtime.protocol._in_file = pipe_file
         nni.runtime.protocol._out_file = pipe_file
-        print('Starting web server...')
+        _logger.info('Starting web server...')
         _check_rest_server(port)
-        print('Setting up...')
-        _init_experiment(config, port, debug)  # todo: kill on fail
+        _logger.info('Setting up...')
+        _init_experiment(config, port, debug)
         return proc, pipe

     except Exception as e:
-        print('Create experiment failed')
+        _logger.error('Create experiment failed')
         if proc is not None:
             with contextlib.suppress(Exception):
                 proc.kill()
...
@@ -58,9 +65,13 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:

 def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str) -> Popen:
+    ts = config.training_service.platform
+    if ts == 'openpai':
+        ts = 'pai'
+
     args = {
         'port': port,
-        'mode': config._training_service,
+        'mode': ts,
         'experiment_id': experiment_id,
         'start_mode': 'new',
         'log_level': 'debug' if debug else 'info',
...
@@ -77,15 +88,18 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim
     return Popen(cmd, cwd=node_dir)


-def _check_rest_server(port: int, retry: int = 10) -> None:
-    for _ in range(retry):
+def _check_rest_server(port: int, retry: int = 3) -> None:
+    for i in range(retry):
         with contextlib.suppress(Exception):
             rest.get(port, '/check-status')
             return
+        if i > 0:
+            _logger.warning('Timeout, retry...')
         time.sleep(1)
     rest.get(port, '/check-status')


 def _init_experiment(config: ExperimentConfig, port: int, debug: bool) -> None:
-    rest.put(port, '/experiment/cluster-metadata', config.cluster_metadata_json())
-    rest.post(port, '/experiment', config.experiment_config_json())
+    for cluster_metadata in convert.to_cluster_metadata(config):
+        rest.put(port, '/experiment/cluster-metadata', cluster_metadata)
+    rest.post(port, '/experiment', convert.to_rest_json(config))
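_ensure_port_idle is referenced above but its body falls outside these hunks; conceptually it is a plain TCP probe. A sketch of that idea under that assumption (an illustration, not the file's actual implementation):

    import socket

    def ensure_port_idle(port: int, message: str = None) -> None:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            sock.settimeout(1)
            if sock.connect_ex(('localhost', port)) == 0:  # 0 means something is listening
                suffix = f' ({message})' if message else ''
                raise RuntimeError(f'Port {port} is not idle{suffix}')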
nni/experiment/nni_client.py  View file @ 3ec26b40
...
@@ -26,7 +26,6 @@ import subprocess
 import re
 import json
 import requests
-import yaml

 __all__ = [
     'ExternalExperiment',
...
@@ -260,38 +259,6 @@ class ExternalExperiment:
         self._endpoint = 'http://localhost:{}'.format(self._port)
         self._exp_id = self.get_experiment_profile()['id']

-    def tmp_start_retiarii(self, graph_ir, training_approach, applied_mutators, strategy, exp_config):
-        # prepare search space file which includes base graph IR and mutators
-        search_space = {}
-        search_space['base_model_ir'] = graph_ir
-        search_space['applied_mutators'] = applied_mutators
-        search_space['training_approach'] = training_approach
-        with open('search_space.json', 'w') as f:
-            json.dump(search_space, f)
-        # add advisor config to exp_config
-        exp_config['searchSpacePath'] = 'search_space.json'
-        exp_config['useAnnotation'] = False
-        exp_config['advisor'] = {
-            'codeDir': '.',
-            'classFileName': 'advisor_entry.py',
-            'className': 'RetiariiAdvisor',
-            'classArgs': {
-                'strategy': '{}.{}'.format(strategy['filename'], strategy['funcname'])
-            }
-        }
-        # add trial config to exp_config
-        exp_config['trial'] = {
-            'command': 'python3 -m nni.retiarii.trial_entry',
-            'codeDir': '../..',
-            'gpuNum': 0
-        }
-        # dump exp_config to nni.yml
-        with open('nni.yml', 'w') as f:
-            yaml.dump(exp_config, f)
-        # start experiment
-        self.start_experiment('nni.yml')
-
     def start_experiment(self, config_file, port=None, debug=False):
         """
         Start an experiment with specified configuration file and connect to it.
...
nni/experiment/pipe.py  View file @ 3ec26b40
...
@@ -3,7 +3,7 @@ import os
 import sys

 if sys.platform == 'win32':
-    import _win32
+    import _winapi
     import msvcrt

     class WindowsPipe:
...
@@ -11,27 +11,27 @@ if sys.platform == 'win32':
             self.path: str = r'\\.\pipe\nni-' + experiment_id
             self.file = None

-            self._handle = _win32.CreateNamedPipe(
+            self._handle = _winapi.CreateNamedPipe(
                 self.path,
-                _win32.PIPE_ACCESS_DUPLEX,
-                _win32.PIPE_TYPE_MESSAGE | _win32.PIPE_READMODE_MESSAGE | _win32.PIPE_WAIT,
+                _winapi.PIPE_ACCESS_DUPLEX,
+                _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | _winapi.PIPE_WAIT,
                 1,
                 8192,
                 8192,
                 0,
-                _win32.NULL
+                _winapi.NULL
             )

         def connect(self) -> BufferedIOBase:
-            _win32.ConnectNamedPipe(self._handle, _win32.NULL)
-            fd = msvcrt.open_osfhandle(self._handle)
-            self.file = os.fdopen(fd, 'rwb')
+            _winapi.ConnectNamedPipe(self._handle, _winapi.NULL)
+            fd = msvcrt.open_osfhandle(self._handle, 0)
+            self.file = os.fdopen(fd, 'w+b')
             return self.file

         def close(self) -> None:
             if self.file is not None:
                 self.file.close()
-            _win32.CloseHandle(self._handle)
+            _winapi.CloseHandle(self._handle)

     Pipe = WindowsPipe
...
@@ -52,7 +52,7 @@ else:
         def connect(self) -> BufferedIOBase:
             conn, _ = self._socket.accept()
-            self.file = conn.makefile('rwb')
+            self.file = conn.makefile('w+b')
             return self.file

         def close(self) -> None:
...
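The mode-string fix matters because 'rwb' is not a valid mode for os.fdopen or socket.makefile, while 'w+b' yields a binary file object usable in both directions. A small self-contained illustration with a socket pair:

    import socket

    a, b = socket.socketpair()
    f = a.makefile('w+b')   # read-write binary, like the fixed pipe code
    f.write(b'ping')
    f.flush()
    b.sendall(b'pong')
    print(f.read(4))        # b'pong'
    f.close(); a.close(); b.close()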
nni/runtime/log.py  View file @ 3ec26b40
...
@@ -4,8 +4,11 @@ import logging
 from logging import FileHandler, Formatter, Handler, StreamHandler
 from pathlib import Path
 import sys
+import time
 from typing import Optional

+import colorama

 from .env_vars import dispatcher_env_vars, trial_env_vars
...
@@ -17,6 +20,8 @@ def init_logger() -> None:
     The detection should work in most cases but for `nnictl` and `nni.experiment`.
     They will be identified as "standalone" mode and must configure the logger by themselves.
     """
+    colorama.init()
+
     if dispatcher_env_vars.SDK_PROCESS == 'dispatcher':
         _init_logger_dispatcher()
         return
...
@@ -33,6 +38,15 @@ def init_logger() -> None:
     _init_logger_standalone()


+def init_logger_experiment() -> None:
+    """
+    Initialize logger for `nni.experiment.Experiment`.
+    This function will get invoked after `init_logger()`.
+    """
+    formatter.format = _colorful_format
+
+
 time_format = '%Y-%m-%d %H:%M:%S'

 formatter = Formatter(
...
@@ -40,14 +54,14 @@ formatter = Formatter(
     time_format
 )


 def _init_logger_dispatcher() -> None:
     log_level_map = {
         'fatal': logging.CRITICAL,
         'error': logging.ERROR,
         'warning': logging.WARNING,
         'info': logging.INFO,
-        'debug': logging.DEBUG
+        'debug': logging.DEBUG,
+        'trace': 0
     }

     log_path = _prepare_log_dir(dispatcher_env_vars.NNI_LOG_DIRECTORY) / 'dispatcher.log'
...
@@ -93,6 +107,21 @@ def _setup_logger(name: str, handler: Handler, level: int) -> None:
     logger.setLevel(level)
     logger.propagate = False


+def _colorful_format(record):
+    if record.levelno >= logging.ERROR:
+        color = colorama.Fore.RED
+    elif record.levelno >= logging.WARNING:
+        color = colorama.Fore.YELLOW
+    elif record.levelno >= logging.INFO:
+        color = colorama.Fore.GREEN
+    else:
+        color = colorama.Fore.BLUE
+    msg = color + (record.msg % record.args) + colorama.Style.RESET_ALL
+    time = formatter.formatTime(record, time_format)
+    if record.levelno < logging.INFO:
+        return '[{}] {}:{} {}'.format(time, record.threadName, record.name, msg)
+    else:
+        return '[{}] {}'.format(time, msg)
+
+
 class _LogFileWrapper(TextIOBase):
     # wrap the logger file so that anything written to it will automatically get formatted
...
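With the new formatter installed, INFO-and-above records print as a colored timestamped message, while DEBUG records keep the thread and logger names. Roughly (timestamps illustrative, colors from colorama, and a handler using this formatter must already be attached):

    import logging
    import nni.runtime.log

    nni.runtime.log.init_logger_experiment()  # swaps formatter.format for _colorful_format
    log = logging.getLogger('nni.experiment')
    log.info('Web UI URLs: http://127.0.0.1:8080')
    # [2020-12-11 12:00:00] Web UI URLs: http://127.0.0.1:8080                    (green)
    log.debug('raw command payload')
    # [2020-12-11 12:00:00] MainThread:nni.experiment raw command payload         (blue)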
nni/runtime/msg_dispatcher_base.py  View file @ 3ec26b40
...
@@ -25,11 +25,11 @@ class MsgDispatcherBase(Recoverable):
     """

     def __init__(self):
+        self.stopping = False
         if multi_thread_enabled():
             self.pool = ThreadPool()
             self.thread_results = []
         else:
-            self.stopping = False
             self.default_command_queue = Queue()
             self.assessor_command_queue = Queue()
             self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
...
@@ -43,11 +43,11 @@ class MsgDispatcherBase(Recoverable):
         """Run the tuner.
         This function will never return unless raise.
         """
-        _logger.info('Start dispatcher')
+        _logger.info('Dispatcher started')
         if dispatcher_env_vars.NNI_MODE == 'resume':
             self.load_checkpoint()

-        while True:
+        while not self.stopping:
             command, data = receive()
             if data:
                 data = json_tricks.loads(data)
...
@@ -75,7 +75,7 @@ class MsgDispatcherBase(Recoverable):
         self.default_worker.join()
         self.assessor_worker.join()
-        _logger.info('Terminated by NNI manager')
+        _logger.info('Dispatcher terminated')

     def command_queue_worker(self, command_queue):
         """Process commands in command queues.
...
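The new stopping flag makes shutdown cooperative: Experiment.stop() sets it and joins the dispatcher thread instead of killing it. A stripped-down sketch of that pattern, independent of NNI's actual message protocol:

    import queue
    import threading
    import time

    class Dispatcher:
        def __init__(self):
            self.stopping = False
            self._commands = queue.Queue()

        def run(self):
            while not self.stopping:
                try:
                    # bounded wait so the flag is re-checked periodically
                    command = self._commands.get(timeout=0.5)
                except queue.Empty:
                    continue
                print('handling', command)

    dispatcher = Dispatcher()
    thread = threading.Thread(target=dispatcher.run)
    thread.start()
    dispatcher._commands.put('report_metric')
    time.sleep(1)
    dispatcher.stopping = True   # what Experiment.stop() does
    thread.join(timeout=1)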