Unverified commit 3ec26b40, authored Dec 11, 2020 by liuzhe-lz, committed by GitHub on Dec 11, 2020

Merge master into dev-retiarii (#3178)

parent d165905d
Changes: 327 files in total; showing 20 changed files with 1112 additions and 383 deletions (+1112 −383).
nni/compression/pytorch/compressor.py                  +62   −11
nni/compression/pytorch/pruning/__init__.py            +4    −0
nni/compression/pytorch/pruning/apply_compression.py   +29   −0
nni/compression/pytorch/speedup/compress_modules.py    +118  −11
nni/compression/pytorch/speedup/infer_shape.py         +106  −7
nni/compression/pytorch/utils/mask_conflict.py         +49   −22
nni/compression/pytorch/utils/shape_dependency.py      +28   −14
nni/experiment/__init__.py                             +1    −1
nni/experiment/config/__init__.py                      +4    −2
nni/experiment/config/base.py                          +145  −107
nni/experiment/config/common.py                        +145  −0
nni/experiment/config/convert.py                       +228  −0
nni/experiment/config/local.py                         +16   −30
nni/experiment/config/util.py                          +54   −0
nni/experiment/experiment.py                           +52   −117
nni/experiment/launcher.py                             +26   −12
nni/experiment/nni_client.py                           +0    −33
nni/experiment/pipe.py                                 +10   −10
nni/runtime/log.py                                     +31   −2
nni/runtime/msg_dispatcher_base.py                     +4    −4
nni/compression/pytorch/compressor.py

```diff
@@ -580,17 +580,55 @@ class QuantType:
     """
     Enum class for quantization type.
     """
-    QUANT_INPUT = 0
-    QUANT_WEIGHT = 1
-    QUANT_OUTPUT = 2
+    QUANT_INPUT = 'input'
+    QUANT_WEIGHT = 'weight'
+    QUANT_OUTPUT = 'output'

 class QuantGrad(torch.autograd.Function):
     """
     Base class for overriding the backward function of a quantization operation.
     """
+    @classmethod
+    def _quantize(cls, x, scale, zero_point):
+        """
+        Reference function for quantizing x -- non-clamped.
+        Parameters
+        ----------
+        x : Tensor
+            tensor to be quantized
+        scale : Tensor
+            scale for quantizing x
+        zero_point : Tensor
+            zero_point for quantizing x
+        Returns
+        -------
+        tensor
+            quantized x, without clamping
+        """
+        return ((x / scale) + zero_point).round()
+
+    @classmethod
+    def get_bits_length(cls, config, quant_type):
+        """
+        Get the bit width from a quantization config.
+        Parameters
+        ----------
+        config : Dict
+            the configuration for quantization
+        quant_type : str
+            quant type
+        Returns
+        -------
+        int
+            n-bits for the quantization configuration
+        """
+        if isinstance(config["quant_bits"], int):
+            return config["quant_bits"]
+        else:
+            return config["quant_bits"].get(quant_type)
+
     @staticmethod
-    def quant_backward(tensor, grad_output, quant_type):
+    def quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
         """
         This method should be overridden by subclasses to provide a customized backward function;
         the default implementation is the Straight-Through Estimator.
@@ -600,32 +638,45 @@ class QuantGrad(torch.autograd.Function):
             input of quantization operation
         grad_output : Tensor
             gradient of the output of quantization operation
-        quant_type : QuantType
-            the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`,
-            you can define different behavior for different types.
+        scale : Tensor
+            scale for quantizing tensor
+        zero_point : Tensor
+            zero_point for quantizing tensor
+        qmin : Tensor
+            quant_min for quantizing tensor
+        qmax : Tensor
+            quant_max for quantizing tensor
         Returns
         -------
         tensor
             gradient of the input of quantization operation
         """
+        tensor_q = QuantGrad._quantize(tensor, scale, zero_point)
+        mask = (tensor_q < qmin) | (tensor_q > qmax)
+        grad_output[mask] = 0
         return grad_output

     @staticmethod
     def forward(ctx, tensor, quant_type, wrapper, **kwargs):
-        ctx.save_for_backward(tensor, torch.Tensor([quant_type]))
         if quant_type == QuantType.QUANT_INPUT:
-            return wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
+            output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
         elif quant_type == QuantType.QUANT_WEIGHT:
-            return wrapper.quantizer.quantize_weight(wrapper, **kwargs)
+            output = wrapper.quantizer.quantize_weight(wrapper, **kwargs)
         elif quant_type == QuantType.QUANT_OUTPUT:
-            return wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
+            output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
         else:
             raise ValueError("unrecognized QuantType.")
+        bits = QuantGrad.get_bits_length(wrapper.config, quant_type)
+        qmin, qmax = torch.Tensor([0], device=tensor.device), torch.Tensor([(1 << bits) - 1], device=tensor.device)
+        ctx.save_for_backward(tensor, wrapper.module.scale, wrapper.module.zero_point, qmin, qmax)
+        return output

     @classmethod
     def backward(cls, ctx, grad_output):
-        tensor, quant_type = ctx.saved_variables
-        output = cls.quant_backward(tensor, grad_output, quant_type)
+        tensor, scale, zero_point, qmin, qmax = ctx.saved_variables
+        output = cls.quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax)
         return output, None, None, None

 def _check_weight(module):
```
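The new `quant_backward` signature carries the quantization parameters directly, so a subclass can implement its own estimator without consulting `quant_type`. A minimal sketch (the subclass name and its clipping rule are illustrative, not part of the commit):

```python
import torch
from nni.compression.pytorch.compressor import QuantGrad

class ClippedQuantGrad(QuantGrad):
    @staticmethod
    def quant_backward(tensor, grad_output, scale, zero_point, qmin, qmax):
        # clipped straight-through estimator: zero the gradient wherever the
        # (unclamped) quantized value falls outside [qmin, qmax]
        tensor_q = ((tensor / scale) + zero_point).round()
        grad_output[(tensor_q < qmin) | (tensor_q > qmax)] = 0
        return grad_output
```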
nni/compression/pytorch/pruning/__init__.py (new file, 0 → 100644)

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .apply_compression import apply_compression_results
```
nni/compression/pytorch/pruning/apply_compression.py (new file, 0 → 100644)

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging
import torch

logger = logging.getLogger('torch apply compression')


def apply_compression_results(model, masks_file, map_location=None):
    """
    Apply the masks from ```masks_file``` to the model.
    Note: this API is for inference only, because it simply multiplies weights with
    their corresponding masks when it is called.

    Parameters
    ----------
    model : torch.nn.Module
        The model to be compressed
    masks_file : str
        The path of the mask file
    map_location : str
        the device on which masks are placed, same as map_location in ```torch.load```
    """
    masks = torch.load(masks_file, map_location)
    for name, module in model.named_modules():
        if name in masks:
            module.weight.data = module.weight.data.mul_(masks[name]['weight'])
            if hasattr(module, 'bias') and module.bias is not None and 'bias' in masks[name]:
                module.bias.data = module.bias.data.mul_(masks[name]['bias'])
```
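A short usage sketch of the new helper; the model is a toy and the mask file is assumed to come from an earlier pruner export step:

```python
import torch
import torch.nn as nn
from nni.compression.pytorch.pruning import apply_compression_results

model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Conv2d(8, 4, 3))
# 'masks.pth' is assumed to have been produced earlier, e.g. by a pruner's
# export step (such as pruner.export_model(..., mask_path='masks.pth'))
apply_compression_results(model, 'masks.pth', map_location='cpu')
model.eval()  # weights are now multiplied by their masks, ready for inference
```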
nni/compression/pytorch/speedup/compress_modules.py

```diff
@@ -10,6 +10,7 @@ _logger = logging.getLogger(__name__)
 replace_module = {
     'BatchNorm2d': lambda module, mask: replace_batchnorm2d(module, mask),
     'Conv2d': lambda module, mask: replace_conv2d(module, mask),
+    'ConvTranspose2d': lambda module, mask: replace_convtranspose2d(module, mask),
     'MaxPool2d': lambda module, mask: no_replace(module, mask),
     'AvgPool2d': lambda module, mask: no_replace(module, mask),
     'AdaptiveAvgPool2d': lambda module, mask: no_replace(module, mask),
@@ -22,6 +23,7 @@ replace_module = {
     'Dropout3d': lambda module, mask: no_replace(module, mask)
 }

 def no_replace(module, mask):
     """
     No need to replace
@@ -29,6 +31,7 @@ def no_replace(module, mask):
     _logger.debug("no need to replace")
     return module

 def replace_linear(linear, mask):
     """
     Parameters
@@ -54,11 +57,13 @@ def replace_linear(linear, mask):
                                  out_features=linear.out_features,
                                  bias=linear.bias is not None)
     new_linear.to(linear.weight.device)
-    new_linear.weight.data = torch.index_select(linear.weight.data, -1, index.to(linear.weight.device))
+    new_linear.weight.data = torch.index_select(
+        linear.weight.data, -1, index.to(linear.weight.device))
     if linear.bias is not None:
         new_linear.bias.data.copy_(linear.bias.data)
     return new_linear

 def replace_batchnorm2d(norm, mask):
     """
     Parameters
@@ -87,10 +92,13 @@ def replace_batchnorm2d(norm, mask):
     new_norm.weight.data = torch.index_select(norm.weight.data, 0, index)
     new_norm.bias.data = torch.index_select(norm.bias.data, 0, index)
-    new_norm.running_mean.data = torch.index_select(norm.running_mean.data, 0, index)
-    new_norm.running_var.data = torch.index_select(norm.running_var.data, 0, index)
+    if norm.track_running_stats:
+        new_norm.running_mean.data = torch.index_select(
+            norm.running_mean.data, 0, index)
+        new_norm.running_var.data = torch.index_select(
+            norm.running_var.data, 0, index)
     return new_norm

 def replace_conv2d(conv, mask):
     """
     Parameters
@@ -121,7 +129,8 @@ def replace_conv2d(conv, mask):
         # remove groups for depthwise layers
         assert in_channels == out_channels
         groups = in_channels
-    _logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d", mask.module_name, in_channels, out_channels)
+    _logger.debug("replace conv2d %s with in_channels: %d, out_channels: %d",
+                  mask.module_name, in_channels, out_channels)
     new_conv = torch.nn.Conv2d(in_channels=in_channels,
                                out_channels=out_channels,
                                kernel_size=conv.kernel_size,
@@ -136,9 +145,11 @@ def replace_conv2d(conv, mask):
     tmp_weight_data = tmp_bias_data = None
     if mask.output_mask is not None:
-        tmp_weight_data = torch.index_select(conv.weight.data, 0, out_channels_index)
+        tmp_weight_data = torch.index_select(
+            conv.weight.data, 0, out_channels_index)
         if conv.bias is not None:
-            tmp_bias_data = torch.index_select(conv.bias.data, 0, out_channels_index)
+            tmp_bias_data = torch.index_select(
+                conv.bias.data, 0, out_channels_index)
     else:
         tmp_weight_data = conv.weight.data
     # For the convolutional layers that have more than one group
@@ -152,24 +163,120 @@ def replace_conv2d(conv, mask):
         for groupid in range(conv.groups):
             start = groupid * input_step
             end = (groupid + 1) * input_step
-            current_input_index = list(filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
+            current_input_index = list(
+                filter(lambda x: start <= x and x < end, in_channels_index.tolist()))
             if not current_input_index:
                 # there is no kept channel in current group
-                continue
+                # TODO bug here: groups is taken directly from conv.groups; if a whole
+                # group is removed, the number of groups in new_conv also needs to change
+                raise Exception(
+                    " Donnot support removing the whole group filter except in the depth-wise conv temporarily")
             # shift the global index into the group index
             current_input_index = [x - start for x in current_input_index]
-            current_input_index = torch.tensor(current_input_index).to(tmp_weight_data.device)  # pylint: disable=not-callable
+            # if groups is larger than 1, the input channels of each
+            # group should be pruned evenly.
+            assert len(current_input_index) == in_channels_group, \
+                'Input channels of each group are not pruned evenly'
+            current_input_index = torch.tensor(current_input_index).to(
+                tmp_weight_data.device)  # pylint: disable=not-callable
             f_start = groupid * filter_step
             f_end = (groupid + 1) * filter_step
-            new_conv.weight.data[f_start:f_end] = torch.index_select(tmp_weight_data[f_start:f_end], 1, current_input_index)
+            new_conv.weight.data[f_start:f_end] = torch.index_select(
+                tmp_weight_data[f_start:f_end], 1, current_input_index)
     else:
         new_conv.weight.data.copy_(tmp_weight_data)

     if conv.bias is not None:
-        new_conv.bias.data.copy_(conv.bias.data if tmp_bias_data is None else tmp_bias_data)
+        new_conv.bias.data.copy_(
+            conv.bias.data if tmp_bias_data is None else tmp_bias_data)

     return new_conv

+def replace_convtranspose2d(convtrans, mask):
+    """
+    We need another replace function for convtranspose2d, because the layout
+    of its weight is different from traditional conv layers: the layout of
+    the weight is [N_in, N_out, ksize_1, ksize_2].
+    Parameters
+    ----------
+    convtrans : torch.nn.ConvTranspose2d
+        The conv2d module to be replaced
+    mask : ModuleMasks
+        The masks of this module
+    Returns
+    -------
+    torch.nn.ConvTranspose2d
+        The new conv2d module
+    """
+    assert isinstance(mask, ModuleMasks)
+    assert isinstance(convtrans, torch.nn.ConvTranspose2d)
+    if mask.input_mask is None:
+        in_channels = convtrans.in_channels
+    else:
+        in_channels_index = mask.input_mask.mask_index[1]
+        in_channels = in_channels_index.size(0)
+    if mask.output_mask is None:
+        out_channels = convtrans.out_channels
+    else:
+        out_channels_index = mask.output_mask.mask_index[1]
+        out_channels = out_channels_index.size(0)
+    groups = convtrans.groups
+    # check if we can remove the whole group of filters
+    if convtrans.in_channels == convtrans.out_channels == convtrans.groups:
+        # remove groups for depthwise layers
+        # this needs the group dependency to be fixed before the speedup
+        assert in_channels == out_channels
+        groups = in_channels
+    _logger.debug('Replace convtranspose2d %s with in_channels:%d out_channels:%d',
+                  mask.module_name, in_channels, out_channels)
+    new_convtrans = torch.nn.ConvTranspose2d(in_channels=in_channels,
+                                             out_channels=out_channels,
+                                             kernel_size=convtrans.kernel_size,
+                                             stride=convtrans.stride,
+                                             padding=convtrans.padding,
+                                             dilation=convtrans.dilation,
+                                             groups=groups,
+                                             bias=convtrans.bias is not None,
+                                             padding_mode=convtrans.padding_mode)
+    new_convtrans.to(convtrans.weight.device)
+    tmp_weight_data = None
+    if mask.input_mask is not None:
+        # in convtranspose2d we need to select the input channel first
+        tmp_weight_data = torch.index_select(
+            convtrans.weight.data, 0, in_channels_index)
+    else:
+        tmp_weight_data = convtrans.weight.data
+    # we need to handle the output channels group by group like the conv layer
+    out_step = int(convtrans.out_channels / convtrans.groups)
+    out_channel_group = int(out_channels / groups)
+    new_in_per_group = int(in_channels / groups)
+    if mask.output_mask is not None and not (in_channels == out_channels == groups):
+        for groupid in range(convtrans.groups):
+            start = groupid * out_step
+            end = (groupid + 1) * out_step
+            current_output_index = list(
+                filter(lambda x: start <= x and x < end, out_channels_index.tolist()))
+            # we need to shift the index into the group-wise index
+            current_output_index = [x - start for x in current_output_index]
+            if not current_output_index:
+                # No kept channel in the current group
+                raise Exception(
+                    " Donnot support removing the whole group filter except in the depth-wise conv temporarily")
+            assert len(current_output_index) == out_channel_group, \
+                'Output channel of each group should be the same after pruning'
+            current_output_index = torch.tensor(current_output_index).to(
+                tmp_weight_data.device)  # pylint: disable=not-callable
+            new_start = groupid * new_in_per_group
+            new_end = (groupid + 1) * new_in_per_group
+            new_convtrans.weight.data[new_start:new_end] = torch.index_select(
+                tmp_weight_data[new_start:new_end], 1, current_output_index)
+    else:
+        new_convtrans.weight.data.copy_(tmp_weight_data)
+    if convtrans.bias is not None:
+        if mask.output_mask is not None:
+            new_convtrans.bias.data[:] = torch.index_select(
+                convtrans.bias.data, 0, out_channels_index)
+        else:
+            new_convtrans.bias.data.copy_(convtrans.bias.data)
+    return new_convtrans
```
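The reason `replace_convtranspose2d` exists at all is the transposed weight layout; a quick check makes the difference concrete (plain PyTorch, independent of this commit):

```python
import torch
import torch.nn as nn

# Conv2d stores weight as [C_out, C_in/groups, k, k]; ConvTranspose2d stores
# [C_in, C_out/groups, k, k], so pruning output channels selects along dim 0
# for conv but dim 1 for transposed conv.
conv = nn.Conv2d(8, 16, 3)
deconv = nn.ConvTranspose2d(8, 16, 3)
print(conv.weight.shape)    # torch.Size([16, 8, 3, 3])
print(deconv.weight.shape)  # torch.Size([8, 16, 3, 3])

keep = torch.tensor([0, 2, 5])  # output channels to keep
pruned_conv_w = torch.index_select(conv.weight.data, 0, keep)
pruned_deconv_w = torch.index_select(deconv.weight.data, 1, keep)
```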
nni/compression/pytorch/speedup/infer_shape.py

```diff
@@ -13,6 +13,7 @@ _logger = logging.getLogger(__name__)
 conv_prune_dim = -1

 def set_conv_prune_dim(dim):
     """
     Parameters:
@@ -23,6 +24,7 @@ def set_conv_prune_dim(dim):
     global conv_prune_dim
     conv_prune_dim = dim

 class CoarseMask:
     """
     Coarse grained mask for a given tensor, here tensor could be weights,
@@ -228,6 +230,7 @@ Infer input and output shape of a module/function from its weight mask
 infer_from_mask = {
     'BatchNorm2d': lambda module_masks, mask: batchnorm2d_mask(module_masks, mask),
     'Conv2d': lambda module_masks, mask: conv2d_mask(module_masks, mask),
+    'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_mask(module_masks, mask),
     'Linear': lambda module_masks, mask, shape: linear_mask(module_masks, mask, shape)
 }
@@ -246,6 +249,7 @@ infer_from_inshape = {
     'aten::relu_': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'aten::sigmoid': lambda module_masks, mask: relu_inshape(module_masks, mask),
     'Conv2d': lambda module_masks, mask: conv2d_inshape(module_masks, mask),
+    'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_inshape(module_masks, mask),
     'MaxPool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
     'aten::max_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
     'aten::avg_pool2d': lambda module_masks, mask: maxpool2d_inshape(module_masks, mask),
@@ -277,6 +281,7 @@ Infer input and weight shape of a module/function from its output shape
 infer_from_outshape = {
     'Conv2d': lambda module_masks, mask: conv2d_outshape(module_masks, mask),
+    'ConvTranspose2d': lambda module_masks, mask: convtranspose2d_outshape(module_masks, mask),
     'BatchNorm2d': lambda module_masks, mask: batchnorm2d_outshape(module_masks, mask),
     'MaxPool2d': lambda module_masks, mask: maxpool2d_outshape(module_masks, mask),
@@ -306,6 +311,7 @@ infer_from_outshape = {
     'aten::dropout': lambda module_masks, mask: dropout_outshape(module_masks, mask)
 }

 def dropout_inshape(module_masks, mask):
     if module_masks.input_mask is None:
         module_masks.set_input_mask(mask)
@@ -325,6 +331,7 @@ def dropout_inshape(module_masks, mask):
     module_masks.set_output_mask(mask)
     return module_masks.output_mask

 def dropout_outshape(module_masks, mask):
     if module_masks.output_mask is None:
         module_masks.set_output_mask(mask)
@@ -335,6 +342,7 @@ def dropout_outshape(module_masks, mask):
     return module_masks.output_mask

 def cat_inshape(module_masks, mask, cat_info, last_visited):
     """
     Infer the output mask of the cat operation from the
@@ -433,6 +441,7 @@ def add_inshape(module_masks, mask):
         raise Exception('Mask conflict happenes!')
     return None

 def add_outshape(module_masks, mask):
     """
     Infer the input mask of the add operation from the
@@ -445,9 +454,11 @@ def add_outshape(module_masks, mask):
         module_masks.set_input_mask(mask)
         return mask
     else:
-        assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
+        assert all(
+            module_masks.output_mask.mask_index[1] == mask.mask_index[1])
         return mask

 def batchnorm2d_inshape(module_masks, mask):
     """
     We assume only the second dimension has a coarse grained mask
@@ -477,6 +488,7 @@ def batchnorm2d_inshape(module_masks, mask):
     module_masks.set_param_masks('bias', weight_cmask)
     return mask

 def batchnorm2d_outshape(module_masks, mask):
     """
     We assume only the second dimension has a coarse grained mask
@@ -577,6 +589,7 @@ def view_inshape(module_masks, mask, shape):
     module_masks.set_output_mask(output_cmask)
     return output_cmask

 def view_outshape(module_masks, mask, shape):
     """
     Parameters
@@ -614,12 +627,14 @@ def view_outshape(module_masks, mask, shape):
     return input_cmask

 def size_inshape(module_masks, mask):
     """
     No need to do anything for this ```size``` op
     """
     return None

 def mean_inshape(module_masks, mask, shape):
     """
     Similar to view operation, currently mask inference only supports
@@ -642,6 +657,7 @@ def mean_inshape(module_masks, mask, shape):
     module_masks.set_output_mask(output_cmask)
     return output_cmask

 def mean_outshape(module_masks, mask, shape):
     """
     Similar to view operation, currently mask inference only supports
@@ -662,6 +678,7 @@ def mean_outshape(module_masks, mask, shape):
     module_masks.set_input_mask(input_cmask)
     return input_cmask

 def maxpool2d_inshape(module_masks, mask):
     """
     Assume only the second dimension is masked
@@ -690,6 +707,7 @@ def maxpool2d_inshape(module_masks, mask):
     module_masks.set_output_mask(mask)
     return mask

 def maxpool2d_outshape(module_masks, mask):
     """
     Assume only the second dimension is masked
@@ -714,6 +732,7 @@ def maxpool2d_outshape(module_masks, mask):
     module_masks.set_output_mask(mask)
     return mask

 def relu_inshape(module_masks, mask):
     """
     Parameters
@@ -737,6 +756,7 @@ def relu_inshape(module_masks, mask):
     module_masks.set_output_mask(mask)
     return mask

 def relu_outshape(module_masks, mask):
     """
     Parameters
@@ -754,11 +774,13 @@ def relu_outshape(module_masks, mask):
     assert isinstance(mask, CoarseMask)
     if module_masks.output_mask is not None:
         # mask conflict should be solved before speedup
-        assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
+        assert all(
+            module_masks.output_mask.mask_index[1] == mask.mask_index[1])
     module_masks.set_input_mask(mask)
     module_masks.set_output_mask(mask)
     return mask

 def batchnorm2d_mask(module_masks, mask):
     """
     Infer input and output shape from weight mask
@@ -792,6 +814,7 @@ def batchnorm2d_mask(module_masks, mask):
     module_masks.set_output_mask(output_cmask)
     return input_cmask, output_cmask

 def linear_mask(module_masks, mask, shape):
     """
     Infer input and output shape from weight mask with limitations:
@@ -825,6 +848,7 @@ def linear_mask(module_masks, mask, shape):
     module_masks.set_input_mask(input_cmask)
     return input_cmask, None

 def conv2d_mask(module_masks, mask):
     """
     Infer input and output shape from weight mask
@@ -863,8 +887,9 @@ def conv2d_mask(module_masks, mask):
         weight_mask = mask['weight']
         sum_idx = (1, 2, 3) if dim == 0 else (0, 2, 3)
-        index = torch.nonzero(weight_mask.abs().sum(sum_idx) != 0, as_tuple=True)[0]
-        if len(index) == weight_mask.shape[dim]:  # full mask
+        index = torch.nonzero(weight_mask.abs().sum(
+            sum_idx) != 0, as_tuple=True)[0]
+        if len(index) == weight_mask.shape[dim]:
+            # full mask
             index = None

     if index is None:
@@ -882,7 +907,8 @@ def conv2d_mask(module_masks, mask):
         bias_cmask.add_index_mask(dim=0, index=bias_index)
         return index, weight_cmask, bias_cmask

-    index, weight_cmask, bias_cmask = convert_to_coarse_mask(mask, dim=conv_prune_dim)
+    index, weight_cmask, bias_cmask = convert_to_coarse_mask(
+        mask, dim=conv_prune_dim)
     if index is None:
         # TODO: fine grained mask speedup
@@ -910,7 +936,8 @@ def conv2d_mask(module_masks, mask):
         module_masks.set_input_mask(io_cmask)
     else:
         assert module_masks.input_mask == io_cmask
     return module_masks.input_mask, None

 def conv2d_inshape(module_masks, mask):
     """
@@ -972,7 +999,8 @@ def conv2d_outshape(module_masks, mask):
         # mask conflict should be solved by fix_mask_conflict before speedup
         # mask and module_masks.output_mask may have different numbers of dimensions
         # since they could be passed by linear or conv2d
-        assert all(module_masks.output_mask.mask_index[1] == mask.mask_index[1])
+        assert all(
+            module_masks.output_mask.mask_index[1] == mask.mask_index[1])

     weight_cmask = CoarseMask(num_dim=4)
     weight_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
@@ -988,3 +1016,74 @@ def conv2d_outshape(module_masks, mask):
         module_masks.input_mask = mask
         return mask
     return None

+def convtranspose2d_mask(module_masks, mask):
+    # TODO support ConvTranspose2d pruning for the L1FilterPruner
+    raise Exception(
+        "Current Filter pruner cannot prune the ConvTranspose2d, will support pruning ConvTranspose2d later")

+def convtranspose2d_inshape(module_masks, mask):
+    """
+    A shape change of the input tensor does not affect the shape of its output tensor.
+    Parameters
+    ----------
+    module_masks : ModuleMasks
+        The ModuleMasks instance of the conv2d
+    mask : CoarseMask
+        The mask of its input tensor
+    Returns
+    -------
+    CoarseMask
+        The mask of its output tensor
+    """
+    assert isinstance(mask, CoarseMask)
+    if module_masks.input_mask is None:
+        module_masks.set_input_mask(mask)
+    else:
+        # the same conv layer may be accessed more
+        # than once, such as in a concat operation.
+        # mask conflict should be solved by fix_mask_conflict before speedup
+        assert module_masks.input_mask == mask
+    # shape changes pass through depth-wise conv layers
+    m = module_masks.module
+    if m.in_channels == m.out_channels == m.groups:
+        module_masks.output_mask = mask
+        module_masks.input_mask = mask
+        return mask
+    return None

+def convtranspose2d_outshape(module_masks, mask):
+    assert isinstance(mask, CoarseMask)
+    assert mask.mask_index[1] is not None
+    assert mask.mask_index[0] is None
+    assert mask.mask_index[2] is None
+    assert mask.mask_index[3] is None
+
+    if module_masks.output_mask is None:
+        module_masks.output_mask = mask
+    else:
+        # mask conflict should be solved by fix_mask_conflict before speedup
+        # mask and module_masks.output_mask may have different numbers of dimensions
+        # since they could be passed by linear or conv2d
+        assert all(
+            module_masks.output_mask.mask_index[1] == mask.mask_index[1])
+
+    weight_cmask = CoarseMask(num_dim=4)
+    # Note: the memory layout of ConvTranspose2d is C_in, C_out, k1, k2
+    weight_cmask.add_index_mask(dim=1, index=mask.mask_index[1])
+    bias_cmask = CoarseMask(num_dim=1)
+    bias_cmask.add_index_mask(dim=0, index=mask.mask_index[1])
+    module_masks.set_param_masks('weight', weight_cmask)
+    module_masks.set_param_masks('bias', bias_cmask)
+    # shape changes pass through depth-wise conv layers
+    m = module_masks.module
+    if m.in_channels == m.out_channels == m.groups:
+        module_masks.output_mask = mask
+        module_masks.input_mask = mask
+        return mask
+    return None
```
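A hedged sketch of how the module-level `conv_prune_dim` switch is driven (the import path follows this file's location; calling it by hand is purely illustrative -- in the diff it is fed by `detect_mask_prune_dim` via `fix_mask_conflict`):

```python
from nni.compression.pytorch.speedup.infer_shape import set_conv_prune_dim

# dim 0: masks describe pruned output channels (filter pruning);
# dim 1: masks describe pruned input channels
set_conv_prune_dim(0)
```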
nni/compression/pytorch/utils/mask_conflict.py

```diff
@@ -9,6 +9,7 @@ from .utils import get_module_by_name
 # logging.basicConfig(level = logging.DEBUG)
 _logger = logging.getLogger(__name__)

 def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
     """
     MaskConflict fixes the mask conflicts for the channel dependencies
@@ -50,6 +51,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
     masks = padding_cat_mask.fix_mask()
     return masks, fix_channel_mask.conv_prune_dim

 class MaskFix:
     def __init__(self, masks, model=None, dummy_input=None, traced=None):
         # check if the parameters are valid
@@ -74,6 +76,7 @@ class MaskFix:
         """
         torch.save(self.masks, path)

 class CatMaskPadding(MaskFix):
     def __init__(self, masks, model, dummy_input=None, traced=None):
         """
@@ -100,7 +103,8 @@ class CatMaskPadding(MaskFix):
         super(CatMaskPadding, self).__init__(masks, model, dummy_input, traced)

     def fix_mask(self):
-        cat_padding_depen = CatPaddingDependency(self.model, self.dummy_input, self.traced)
+        cat_padding_depen = CatPaddingDependency(
+            self.model, self.dummy_input, self.traced)
         name_to_module = {}
         for name, module in self.model.named_modules():
             name_to_module[name] = module
@@ -131,11 +135,10 @@ class CatMaskPadding(MaskFix):
                     # module.bias may be None
                     b_shape = module.bias.data.size()
                     b_mask = torch.ones(b_shape).to(device)
                 self.masks[layer] = {'weight': w_mask, 'bias': b_mask}

         return self.masks

 class GroupMaskConflict(MaskFix):
     def __init__(self, masks, model=None, dummy_input=None, traced=None):
         """
@@ -154,8 +157,8 @@ class GroupMaskConflict(MaskFix):
         the traced model of the target model; if this parameter is not None,
         we do not use the model and dummy_input to get the trace graph.
         """
-        super(GroupMaskConflict, self).__init__(masks, model, dummy_input, traced)
+        super(GroupMaskConflict, self).__init__(
+            masks, model, dummy_input, traced)

     def fix_mask(self):
         """
@@ -163,7 +166,8 @@ class GroupMaskConflict(MaskFix):
         has group dependencies. This function should be called before the
         mask inference of the 'speedup' module.
         """
-        group_depen = GroupDependency(self.model, self.dummy_input, self.traced)
+        group_depen = GroupDependency(
+            self.model, self.dummy_input, self.traced)
         depens = group_depen.dependency
         _logger.info(depens)
         for layername in depens:
@@ -174,8 +178,10 @@ class GroupMaskConflict(MaskFix):
             w_mask = self.masks[layername]['weight']
             shape = w_mask.size()
             count = np.prod(shape[1:])
-            all_ones = (w_mask.flatten(1).sum(-1) == count).nonzero().squeeze(1).tolist()
-            all_zeros = (w_mask.flatten(1).sum(-1) == 0).nonzero().squeeze(1).tolist()
+            all_ones = (w_mask.flatten(1).sum(-1) ==
+                        count).nonzero().squeeze(1).tolist()
+            all_zeros = (w_mask.flatten(1).sum(-1) ==
+                         0).nonzero().squeeze(1).tolist()
             if len(all_ones) + len(all_zeros) < w_mask.size(0):
                 # In fine-grained pruning, skip this layer
                 _logger.info('Layers %s using fine-grained pruning', layername)
@@ -190,7 +196,8 @@ class GroupMaskConflict(MaskFix):
             for i in range(group):
                 _start = step * i
                 _end = step * (i + 1)
-                _tmp_list = list(filter(lambda x: _start <= x and x < _end, all_zeros))
+                _tmp_list = list(
+                    filter(lambda x: _start <= x and x < _end, all_zeros))
                 group_masked.append(_tmp_list)
             mini_masked = min([len(x) for x in group_masked])
             for gm in group_masked:
@@ -198,13 +205,13 @@ class GroupMaskConflict(MaskFix):
                     # To keep the output channel number still divisible by
                     # groups, we set the masks of the following filters to zero.
                     pos = gm[i]
-                    self.masks[layername]['weight'][pos] = torch.ones(shape[1:])
-                    if hasattr(self.masks[layername], 'bias'):
+                    self.masks[layername]['weight'][pos] = torch.ones(
+                        shape[1:])
+                    if 'bias' in self.masks[layername] and self.masks[layername]['bias'] is not None:
                         self.masks[layername]['bias'][pos] = 1
         return self.masks

 class ChannelMaskConflict(MaskFix):
     def __init__(self, masks, model=None, dummy_input=None, traced=None):
         """
@@ -223,7 +230,8 @@ class ChannelMaskConflict(MaskFix):
         the traced graph of the target model; if this parameter is not None,
         we do not use the model and dummy_input to get the trace graph.
         """
-        super(ChannelMaskConflict, self).__init__(masks, model, dummy_input, traced)
+        super(ChannelMaskConflict, self).__init__(
+            masks, model, dummy_input, traced)
         self.conv_prune_dim = detect_mask_prune_dim(masks, model)
         _logger.info('detected conv prune dim: %s', self.conv_prune_dim)
@@ -235,9 +243,11 @@ class ChannelMaskConflict(MaskFix):
         are supported.
         """
         if self.conv_prune_dim == 0:
-            channel_depen = ChannelDependency(self.model, self.dummy_input, self.traced)
+            channel_depen = ChannelDependency(
+                self.model, self.dummy_input, self.traced)
         else:
-            channel_depen = InputChannelDependency(self.model, self.dummy_input, self.traced)
+            channel_depen = InputChannelDependency(
+                self.model, self.dummy_input, self.traced)
         depen_sets = channel_depen.dependency_sets
         sum_idx = (1, 2, 3) if self.conv_prune_dim == 0 else (0, 2, 3)
         for dset in depen_sets:
@@ -262,17 +272,29 @@ class ChannelMaskConflict(MaskFix):
                     channel_masks.append((mask.abs().sum(0) != 0).int())
                 elif type(m).__name__ == 'BatchNorm2d':
                     channel_masks.append(mask.int())
+                elif type(m).__name__ == 'ConvTranspose2d':
+                    # convtranspose has a different memory layout, so we need to create
+                    # a tmp_sum_idx for conv_transpose
+                    tmp_sum_idx = (
+                        0, 2, 3) if self.conv_prune_dim == 0 else (1, 2, 3)
+                    channel_mask = (mask.abs().sum(tmp_sum_idx) != 0).int()
+                    channel_masks.append(channel_mask)
+                    if (channel_mask.sum() * (mask.numel() / mask.shape[1 - self.conv_prune_dim])).item() != (mask > 0).sum().item():
+                        fine_grained = True
                 else:
-                    raise RuntimeError(f'unsupported module type: {type(m).__name__}')
+                    raise RuntimeError(
+                        f'unsupported module type: {type(m).__name__}')
             else:
                 # no mask means not pruned, equivalent to full masks
                 channel_masks.append(None)
         if fine_grained:
-            _logger.info('fine-grained mask detected, skip solving conflict for this set: %s', dset)
+            _logger.info(
+                'fine-grained mask detected, skip solving conflict for this set: %s', dset)
             continue
         if all(x is None for x in channel_masks):
             continue
-        num_channels_list = [len(x) for x in channel_masks if x is not None]
+        num_channels_list = [len(x)
+                             for x in channel_masks if x is not None]
         # number of channels in same set should be identical
         assert len(set(num_channels_list)) == 1
         num_channels = num_channels_list[0]
@@ -284,7 +306,8 @@ class ChannelMaskConflict(MaskFix):
         # merge masks with 'or'
         merged_channel_mask = channel_masks[0].clone()
         for i in range(1, len(channel_masks)):
-            merged_channel_mask = ((merged_channel_mask + channel_masks[i]) != 0).int()
+            merged_channel_mask = (
+                (merged_channel_mask + channel_masks[i]) != 0).int()
         merged_index = torch.nonzero(merged_channel_mask, as_tuple=True)[0]
@@ -305,16 +328,19 @@ class ChannelMaskConflict(MaskFix):
             elif type(m).__name__ == 'BatchNorm2d':
                 new_mask = merged_index.type_as(orig_mask)
             else:
-                raise RuntimeError(f'unsupported module type: {type(m).__name__}')
+                raise RuntimeError(
+                    f'unsupported module type: {type(m).__name__}')
             self.masks[name]['weight'] = new_mask
             if 'bias' in self.masks[name] and self.masks[name]['bias'] is not None:
                 if type(m).__name__ == 'Conv2d':
                     assert self.conv_prune_dim == 0
-                self.masks[name]['bias'] = merged_channel_mask.type_as(self.masks[name]['bias'])
+                self.masks[name]['bias'] = merged_channel_mask.type_as(
+                    self.masks[name]['bias'])
         return self.masks

 def detect_mask_prune_dim(masks, model):
     """
     Detect how the masks of convolutional layers are pruned.
@@ -358,7 +384,8 @@ def detect_mask_prune_dim(masks, model):
         _logger.warning('no multi-dimension masks found.')
         return 0

-    dim0_sparsity, dim1_sparsity = 1. - dim0_preserved / dim0_num, 1. - dim1_preserved / dim1_num
+    dim0_sparsity, dim1_sparsity = 1. - dim0_preserved / \
+        dim0_num, 1. - dim1_preserved / dim1_num
     _logger.info('dim0 sparsity: %f', dim0_sparsity)
     _logger.info('dim1 sparsity: %f', dim1_sparsity)
```
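A sketch of the intended call order before speedup, per the docstrings above (the mask file is hypothetical; note that `fix_mask_conflict` now also returns the detected conv prune dim):

```python
import torch
from torchvision.models import resnet18
from nni.compression.pytorch.utils.mask_conflict import fix_mask_conflict

model = resnet18()
dummy_input = torch.randn(1, 3, 224, 224)
masks = torch.load('masks.pth')  # hypothetical mask file from a pruner
masks, conv_prune_dim = fix_mask_conflict(masks, model, dummy_input)
torch.save(masks, 'fixed_masks.pth')  # conflict-free masks for the speedup pass
```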
nni/compression/pytorch/utils/shape_dependency.py

```diff
@@ -4,13 +4,16 @@
 import csv
 import logging

-__all__ = ['ChannelDependency', 'GroupDependency', 'CatPaddingDependency', 'InputChannelDependency']
+__all__ = ['ChannelDependency', 'GroupDependency',
+           'CatPaddingDependency', 'InputChannelDependency']

 CONV_TYPE = 'aten::_convolution'
 ADD_TYPES = ['aten::add', 'aten::add_']
 CAT_TYPE = 'aten::cat'
 logger = logging.getLogger('Shape_Dependency')
-RESHAPE_OPS = [CAT_TYPE, 'aten::view', 'aten::reshape', 'aten::flatten', 'aten::mean']
+RESHAPE_OPS = [CAT_TYPE, 'aten::view',
+               'aten::reshape', 'aten::flatten', 'aten::mean']

 class Dependency:
     def __init__(self, model=None, dummy_input=None, traced_model=None):
@@ -34,6 +37,7 @@ class Dependency:
     def export(self, filepath):
         raise NotImplementedError

 class ChannelDependency(Dependency):
     def __init__(self, model=None, dummy_input=None, traced_model=None):
         """
@@ -50,7 +54,8 @@ class ChannelDependency(Dependency):
         if we already have the traced graph of the target model, we do not
         need to trace the model again.
         """
-        super(ChannelDependency, self).__init__(model, dummy_input, traced_model)
+        super(ChannelDependency, self).__init__(
+            model, dummy_input, traced_model)

     def _get_parent_layers(self, node):
         """
@@ -71,7 +76,7 @@ class ChannelDependency(Dependency):
             queue.append(node)
         while queue:
             curnode = queue.pop(0)
-            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear':
+            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
                 # find the first met conv
                 parent_layers.append(curnode.name)
                 continue
@@ -119,7 +124,6 @@ class ChannelDependency(Dependency):
         for _node in dependency_set:
             self.dependency[_node] = dependency_set

     def export(self, filepath):
         """
         export the channel dependencies as a csv file.
@@ -185,6 +189,7 @@ class ChannelDependency(Dependency):
             d_sets.append(tmp_set)
         return d_sets

 def reshape_break_channel_dependency(op_node):
     """
     The reshape operations such as (reshape, view, flatten) may break
@@ -213,6 +218,7 @@ def reshape_break_channel_dependency(op_node):
     out_channel = out_shape[1]
     return in_channel != out_channel

 class InputChannelDependency(ChannelDependency):
     """
     Some pruners may prune the input channel of the convolutional
@@ -242,7 +248,8 @@ class InputChannelDependency(ChannelDependency):
         if we already have the traced graph of the target model, we do not
         need to trace the model again.
         """
-        super(InputChannelDependency, self).__init__(model, dummy_input, traced_model)
+        super(InputChannelDependency, self).__init__(
+            model, dummy_input, traced_model)

     def _get_following_convs(self, tensor):
         queue = []
@@ -250,14 +257,14 @@ class InputChannelDependency(ChannelDependency):
         queue.extend(self.graph.input_to_node[tensor])
         while queue:
             curnode = queue.pop(0)
-            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear':
+            if curnode.op_type == 'Conv2d' or curnode.op_type == 'Linear' or curnode.op_type == 'ConvTranspose2d':
                 # find the first met conv
                 key_layers.append(curnode.name)
                 continue
             elif curnode.op_type in RESHAPE_OPS:
                 # check if the reshape operation will break the channel dependency
                 if reshape_break_channel_dependency(curnode):
                     # reshape operations also break the dependency relationship
                     continue
             successors = self.graph.find_successors(curnode.unique_name)
             successors = [self.graph.name_to_node[name] for name in successors]
@@ -290,7 +297,8 @@ class InputChannelDependency(ChannelDependency):
 class CatPaddingDependency(ChannelDependency):
     def __init__(self, model=None, dummy_input=None, traced_model=None):
-        super(CatPaddingDependency, self).__init__(model, dummy_input, traced_model)
+        super(CatPaddingDependency, self).__init__(
+            model, dummy_input, traced_model)

     def build_dependency(self):
         """
@@ -347,6 +355,7 @@ class CatPaddingDependency(ChannelDependency):
             row.extend(list(layers))
             csv_w.writerow(row)

 class GroupDependency(Dependency):
     def __init__(self, model=None, dummy_input=None, traced_model=None):
         """
@@ -388,7 +397,7 @@ class GroupDependency(Dependency):
         queue = predeessors
         while queue:
             curnode = queue.pop(0)
-            if curnode.op_type == 'Conv2d':
+            if curnode.op_type == 'Conv2d' or curnode.op_type == 'ConvTranspose2d':
                 # find the first met conv
                 parent_layers.append(curnode.name)
                 continue
@@ -412,7 +421,8 @@ class GroupDependency(Dependency):
         group : int
             the number of the groups of the target conv layer.
         """
-        cpp_conv = list(filter(lambda x: x.kind() == CONV_TYPE, node_group.node_cpps))
+        cpp_conv = list(filter(lambda x: x.kind() ==
+                               CONV_TYPE, node_group.node_cpps))
         assert len(cpp_conv) == 1
         cpp_conv = cpp_conv[0]
         inputs = list(cpp_conv.inputs())
@@ -442,12 +452,14 @@ class GroupDependency(Dependency):
             filters should be divisible to.
         """
         for node in self.graph.nodes_py.nodes_op:
-            if node.op_type == 'Conv2d':
+            if node.op_type == 'Conv2d' or node.op_type == 'ConvTranspose2d':
                 group = self._get_conv_groups(node)
                 if node.name in self.dependency:
                     # a conv layer whose group is larger than 1 requires that
                     # its number of output channels be divisible by the number of groups.
-                    self.dependency[node.name] = max(self.dependency[node.name], group)
+                    self.dependency[node.name] = max(
+                        self.dependency[node.name], group)
                 else:
                     self.dependency[node.name] = group
                 if group > 1:
@@ -456,7 +468,8 @@ class GroupDependency(Dependency):
                     parent_convs = self._get_parent_convs(node)
                     for parent in parent_convs:
                         if parent in self.dependency:
-                            self.dependency[parent] = max(self.dependency[parent], group)
+                            self.dependency[parent] = max(
+                                self.dependency[parent], group)
                         else:
                             self.dependency[parent] = group
         return self.dependency
@@ -484,6 +497,7 @@ class GroupDependency(Dependency):
         for name in self.dependency:
             group = self.dependency[name]
             csv_w.writerow([name, group])

     @property
     def dependency_sets(self):
         return self.dependency
```
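A usage sketch for the dependency utilities (assuming, as the base class suggests, that the dependency graph is built at construction time from the model plus a dummy input):

```python
import torch
from torchvision.models import resnet18
from nni.compression.pytorch.utils.shape_dependency import ChannelDependency

model = resnet18()
dummy_input = torch.randn(1, 3, 224, 224)
dep = ChannelDependency(model, dummy_input)
dep.export('channel_dependency.csv')  # one dependency set per csv row
print(dep.dependency_sets)            # layers whose pruned channels must match
```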
nni/experiment/__init__.py

```diff
@@ -2,6 +2,6 @@
 # Licensed under the MIT license.

 from .config import *
-from .experiment import Experiment, RetiariiExperiment
+from .experiment import Experiment
 from .nni_client import *
```
nni/experiment/config/__init__.py

```diff
-from .base import ExperimentConfig, RetiariiExpConfig
-from .local import LocalExperimentConfig
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from .common import *
+from .local import *
```
nni/experiment/config/base.py

This file is essentially rewritten: the old template-based `ExperimentConfig` / `RetiariiExpConfig` classes are removed and replaced by a generic `ConfigBase`.

Removed:

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import dataclasses
import json
from pathlib import Path
from typing import Any, Dict, Optional, Union


@dataclasses.dataclass(init=False)
class ExperimentConfig:
    experiment_name: str
    search_space: Any
    max_execution_seconds: Optional[int] = None
    max_trial_number: Optional[int] = None
    trial_concurrency: int
    trial_command: str
    trial_code_directory: Union[Path, str]
    trial_gpu_number: int = 0
    extra_config: Optional[Dict[str, str]] = None

    _training_service: str

    # these values will be used to create a template object,
    # and the user should overwrite them later.
    _placeholder = {
        'experiment_name': '_unset_',
        'search_space': '_unset_',
        'trial_concurrency': -1,
        'trial_command': '_unset_',
        'trial_code_directory': '_unset_'
    }

    # simple validation functions
    # complex validation logic with special error messages should go to the `validate()` method instead
    _value_range = {
        'max_execution_seconds': lambda x: x is None or x > 0,
        'max_trial_number': lambda x: x is None or x > 0,
        'trial_concurrency': lambda x: x > 0,
        'trial_gpu_number': lambda x: x >= 0
    }

    def __init__(self, **kwargs):
        for field in dataclasses.fields(self):
            if field.name in kwargs:
                setattr(self, field.name, kwargs[field.name])
            elif field.default != dataclasses.MISSING:
                setattr(self, field.name, field.default)
            else:
                setattr(self, field.name, type(self)._placeholder[field.name])

    def validate(self) -> None:
        # check existence
        for key, placeholder_value in type(self)._placeholder.items():
            if getattr(self, key) == placeholder_value:
                raise ValueError(f'Field "{key}" is not set')
        # TODO: check type
        # check value
        for key, condition in type(self)._value_range.items():
            value = getattr(self, key)
            if not condition(value):
                raise ValueError(f'Field "{key}" ({repr(value)}) out of range')
        # check special fields
        if not Path(self.trial_code_directory).is_dir():
            raise ValueError(f'Trial code directory "{self.trial_code_directory}" does not exist or is not directory')

    def experiment_config_json(self) -> Dict[str, Any]:
        # this only contains the common part for most (if not all) training services
        # subclasses should override it to provide exclusive fields
        return {
            'authorName': '_',
            'experimentName': self.experiment_name,
            'trialConcurrency': self.trial_concurrency,
            'maxExecDuration': self.max_execution_seconds or (999 * 24 * 3600),
            'maxTrialNum': self.max_trial_number or 99999,
            'searchSpace': json.dumps(self.search_space),
            'trainingServicePlatform': self._training_service,
            'tuner': {'builtinTunerName': '_user_created_'},
            **(self.extra_config or {})
        }

    def cluster_metadata_json(self) -> Any:
        # the cluster metadata format is a total mess
        # leave it to each subclass before we refactor nni manager
        raise NotImplementedError()

    @staticmethod
    def create_template(training_service: str) -> 'ExperimentConfig':
        for cls in ExperimentConfig.__subclasses__():
            for field in dataclasses.fields(cls):
                if field.name == '_training_service' and field.default == training_service:
                    return cls()
        raise ValueError(f'Unrecognized training service {training_service}')


class RetiariiExpConfig(ExperimentConfig):
    @staticmethod
    def create_template(training_service: str) -> 'ExperimentConfig':
        for cls in ExperimentConfig.__subclasses__():
            for field in dataclasses.fields(cls):
                if field.name == '_training_service' and field.default == training_service:
                    config_obj = cls()
                    config_obj.search_space = {}
                    config_obj.trial_command = 'python3 -m nni.retiarii.trial_entry'
                    # FIXME: expose this field to users
                    config_obj.trial_code_directory = '../..'
                    return config_obj
```

Added:

```python
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import copy
import dataclasses
from pathlib import Path
from typing import Any, Dict, Optional, Type, TypeVar

from ruamel import yaml

from . import util

__all__ = ['ConfigBase', 'PathLike']

T = TypeVar('T', bound='ConfigBase')

PathLike = util.PathLike


def _is_missing(obj: Any) -> bool:
    return isinstance(obj, type(dataclasses.MISSING))


class ConfigBase:
    """
    Base class of config classes.
    Subclass may override `_canonical_rules` and `_validation_rules`,
    and `validate()` if the logic is complex.
    """

    # Rules to convert field value to canonical format.
    # The key is field name.
    # The value is callable `value -> canonical_value`.
    # It is not type-hinted so dataclass won't treat it as field
    _canonical_rules = {}  # type: ignore

    # Rules to validate field value.
    # The key is field name.
    # The value is callable `value -> valid` or `value -> (valid, error_message)`.
    # The rule will be called with canonical format and is only called when `value` is not None.
    # `error_message` is used when `valid` is False.
    # It will be prepended with class name and field name in the exception message.
    _validation_rules = {}  # type: ignore

    def __init__(self, *, _base_path: Optional[Path] = None, **kwargs):
        """
        Initialize a config object and set some fields.
        Names of keyword arguments can either be snake_case or camelCase.
        They will be converted to snake_case automatically.
        If a field is missing and doesn't have a default value, it will be set to `dataclasses.MISSING`.
        """
        kwargs = {util.case_insensitive(key): value for key, value in kwargs.items()}
        if _base_path is None:
            _base_path = Path()
        for field in dataclasses.fields(self):
            value = kwargs.pop(util.case_insensitive(field.name), field.default)
            if value is not None and not _is_missing(value):
                # relative paths loaded from config file are not relative to pwd
                if 'Path' in str(field.type):
                    value = Path(value).expanduser()
                    if not value.is_absolute():
                        value = _base_path / value
                # convert nested dict to config type
                if isinstance(value, dict):
                    cls = util.strip_optional(field.type)
                    if isinstance(cls, type) and issubclass(cls, ConfigBase):
                        value = cls(**value, _base_path=_base_path)
            setattr(self, field.name, value)
        if kwargs:
            cls = type(self).__name__
            fields = ', '.join(kwargs.keys())
            raise ValueError(f'{cls}: Unrecognized fields {fields}')

    @classmethod
    def load(cls: Type[T], path: PathLike) -> T:
        """
        Load config from YAML (or JSON) file.
        Keys in the YAML file can either be camelCase or snake_case.
        """
        data = yaml.safe_load(open(path))
        if not isinstance(data, dict):
            raise ValueError(f'Content of config file {path} is not a dict/object')
        return cls(**data, _base_path=Path(path).parent)

    def json(self) -> Dict[str, Any]:
        """
        Convert config to JSON object.
        The keys of the returned object will be camelCase.
        """
        return dataclasses.asdict(
            self.canonical(),
            dict_factory=lambda items: dict((util.camel_case(k), v) for k, v in items if v is not None)
        )

    def canonical(self: T) -> T:
        """
        Returns a deep copy, where the fields supporting multiple formats are converted to the canonical format.
        Noticeably, a relative path may be converted to an absolute path.
        """
        ret = copy.deepcopy(self)
        for field in dataclasses.fields(ret):
            key, value = field.name, getattr(ret, field.name)
            rule = ret._canonical_rules.get(key)
            if rule is not None:
                setattr(ret, key, rule(value))
            elif isinstance(value, ConfigBase):
                setattr(ret, key, value.canonical())
                # value will be copied twice, should not be a performance issue anyway
        return ret

    def validate(self) -> None:
        """
        Validate the config object and raise Exception if it's ill-formed.
        """
        class_name = type(self).__name__
        config = self.canonical()

        for field in dataclasses.fields(config):
            key, value = field.name, getattr(config, field.name)

            # check existence
            if _is_missing(value):
                raise ValueError(f'{class_name}: {key} is not set')

            # check type (TODO)
            type_name = str(field.type).replace('typing.', '')
            optional = any([
                type_name.startswith('Optional['),
                type_name.startswith('Union[') and 'NoneType' in type_name,
                type_name == 'Any'
            ])
            if value is None:
                if optional:
                    continue
                else:
                    raise ValueError(f'{class_name}: {key} cannot be None')

            # check value
            rule = config._validation_rules.get(key)
            if rule is not None:
                try:
                    result = rule(value)
                except Exception:
                    raise ValueError(f'{class_name}: {key} has bad value {repr(value)}')
                if isinstance(result, bool):
                    if not result:
                        raise ValueError(f'{class_name}: {key} ({repr(value)}) is out of range')
                else:
                    if not result[0]:
                        raise ValueError(f'{class_name}: {key} {result[1]}')

            # check nested config
            if isinstance(value, ConfigBase):
                value.validate()
```
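To see what the new `ConfigBase` buys, here is a minimal subclass following the pattern above (the class and its fields are invented for illustration):

```python
from dataclasses import dataclass
from nni.experiment.config.base import ConfigBase

@dataclass(init=False)
class DemoConfig(ConfigBase):
    name: str = 'demo'
    port: int = 8080

    # plain class attribute (not type-hinted), as the base class expects
    _validation_rules = {
        'port': lambda value: (0 < value < 65536, 'must be a valid TCP port')
    }

config = DemoConfig(port=8081)  # camelCase keys would be accepted too
config.validate()               # raises ValueError for ill-formed fields
print(config.json())            # camelCase keys, None fields omitted
```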
nni/experiment/config/common.py
0 → 100644
View file @
3ec26b40
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Union
from
.base
import
ConfigBase
,
PathLike
from
.
import
util
__all__
=
[
'ExperimentConfig'
,
'AlgorithmConfig'
,
'CustomAlgorithmConfig'
,
'TrainingServiceConfig'
,
]
@
dataclass
(
init
=
False
)
class
_AlgorithmConfig
(
ConfigBase
):
name
:
Optional
[
str
]
=
None
class_name
:
Optional
[
str
]
=
None
code_directory
:
Optional
[
PathLike
]
=
None
class_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
def
validate
(
self
):
super
().
validate
()
_validate_algo
(
self
)
@
dataclass
(
init
=
False
)
class
AlgorithmConfig
(
_AlgorithmConfig
):
name
:
str
class_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
@
dataclass
(
init
=
False
)
class
CustomAlgorithmConfig
(
_AlgorithmConfig
):
class_name
:
str
class_directory
:
Optional
[
PathLike
]
=
None
class_args
:
Optional
[
Dict
[
str
,
Any
]]
=
None
class
TrainingServiceConfig
(
ConfigBase
):
platform
:
str
@
dataclass
(
init
=
False
)
class
ExperimentConfig
(
ConfigBase
):
experiment_name
:
Optional
[
str
]
=
None
search_space_file
:
Optional
[
PathLike
]
=
None
search_space
:
Any
=
None
trial_command
:
str
trial_code_directory
:
PathLike
=
'.'
trial_concurrency
:
int
trial_gpu_number
:
int
=
0
max_experiment_duration
:
Optional
[
str
]
=
None
max_trial_number
:
Optional
[
int
]
=
None
nni_manager_ip
:
Optional
[
str
]
=
None
use_annotation
:
bool
=
False
debug
:
bool
=
False
log_level
:
Optional
[
str
]
=
None
experiment_working_directory
:
Optional
[
PathLike
]
=
None
tuner_gpu_indices
:
Optional
[
Union
[
List
[
int
],
str
]]
=
None
tuner
:
Optional
[
_AlgorithmConfig
]
=
None
accessor
:
Optional
[
_AlgorithmConfig
]
=
None
advisor
:
Optional
[
_AlgorithmConfig
]
=
None
training_service
:
TrainingServiceConfig
    def __init__(self, training_service_platform: Optional[str] = None, **kwargs):
        super().__init__(**kwargs)
        if training_service_platform is not None:
            assert 'training_service' not in kwargs
            self.training_service = util.training_service_config_factory(training_service_platform)

    def validate(self, initialized_tuner: bool = False) -> None:
        super().validate()
        if initialized_tuner:
            _validate_for_exp(self)
        else:
            _validate_for_nnictl(self)

    ## End of public API ##

    @property
    def _canonical_rules(self):
        return _canonical_rules

    @property
    def _validation_rules(self):
        return _validation_rules


_canonical_rules = {
    'search_space_file': util.canonical_path,
    'trial_code_directory': util.canonical_path,
    'max_experiment_duration': lambda value: f'{util.parse_time(value)}s' if value is not None else None,
    'experiment_working_directory': util.canonical_path,
    'tuner_gpu_indices': lambda value: [int(idx) for idx in value.split(',')] if isinstance(value, str) else value
}

_validation_rules = {
    'search_space_file': lambda value: (Path(value).is_file(), f'"{value}" does not exist or is not a regular file'),
    'trial_code_directory': lambda value: (Path(value).is_dir(), f'"{value}" does not exist or is not a directory'),
    'trial_concurrency': lambda value: value > 0,
    'trial_gpu_number': lambda value: value >= 0,
    'max_experiment_duration': lambda value: util.parse_time(value) > 0,
    'max_trial_number': lambda value: value > 0,
    'log_level': lambda value: value in ['trace', 'debug', 'info', 'warning', 'error', 'fatal'],
    'tuner_gpu_indices': lambda value: all(i >= 0 for i in value) and len(value) == len(set(value)),
    'training_service': lambda value: (type(value) is not TrainingServiceConfig, 'cannot be abstract base class')
}


def _validate_for_exp(config: ExperimentConfig) -> None:
    # validate experiment for nni.Experiment, where tuner is already initialized outside
    if config.use_annotation:
        raise ValueError('ExperimentConfig: annotation is not supported in this mode')
    if util.count(config.search_space, config.search_space_file) != 1:
        raise ValueError('ExperimentConfig: exactly one of search_space and search_space_file must be set')
    if util.count(config.tuner, config.accessor, config.advisor) != 0:
        raise ValueError('ExperimentConfig: tuner, accessor, and advisor must not be set for this mode')
    if config.tuner_gpu_indices is not None:
        raise ValueError('ExperimentConfig: tuner_gpu_indices is not supported in this mode')


def _validate_for_nnictl(config: ExperimentConfig) -> None:
    # validate experiment for normal launching approach
    if config.use_annotation:
        if util.count(config.search_space, config.search_space_file) != 0:
            raise ValueError('ExperimentConfig: search_space and search_space_file must not be set with annotation')
    else:
        if util.count(config.search_space, config.search_space_file) != 1:
            raise ValueError('ExperimentConfig: exactly one of search_space and search_space_file must be set')
    if util.count(config.tuner, config.advisor) != 1:
        raise ValueError('ExperimentConfig: exactly one of tuner and advisor must be set')


def _validate_algo(algo: AlgorithmConfig) -> None:
    if algo.name is None:
        if algo.class_name is None:
            raise ValueError('Missing algorithm name')
        if algo.code_directory is not None and not Path(algo.code_directory).is_dir():
            raise ValueError(f'code_directory "{algo.code_directory}" does not exist or is not a directory')
    else:
        if algo.class_name is not None or algo.code_directory is not None:
            raise ValueError('When name is set for a registered algorithm, class_name and code_directory cannot be used')
    # TODO: verify algorithm installation and class args
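The two module-level rule tables above drive canonicalization and validation field by field. A minimal sketch of how such tables can be applied — the `apply_rules` helper below is an illustration, not the actual base-class code:

# Hypothetical illustration of how the rule tables could be consumed.
# The real logic lives in the config base class; names here are assumptions.
from dataclasses import fields

def apply_rules(config):
    for field in fields(config):
        value = getattr(config, field.name)
        # canonicalize first, e.g. '1h' -> '3600s', '0,1' -> [0, 1]
        rule = _canonical_rules.get(field.name)
        if rule is not None and value is not None:
            value = rule(value)
            setattr(config, field.name, value)
        # then validate; rules return bool or (bool, error_message)
        check = _validation_rules.get(field.name)
        if check is not None and value is not None:
            result = check(value)
            ok, msg = result if isinstance(result, tuple) else (result, 'failed validation')
            if not ok:
                raise ValueError(f'{field.name}: {msg}')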
nni/experiment/config/convert.py
0 → 100644
View file @ 3ec26b40
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import json
import logging
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Any, Dict, List

from .common import ExperimentConfig
from . import util

_logger = logging.getLogger(__name__)


def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str, Any]:
    config.validate(skip_nnictl)
    data = config.json()

    ts = data.pop('trainingService')
    if ts['platform'] == 'openpai':
        ts['platform'] = 'pai'

    data['authorName'] = 'N/A'
    data['experimentName'] = data.get('experimentName', 'N/A')
    data['maxExecDuration'] = data.pop('maxExperimentDuration', '999d')
    if data['debug']:
        data['versionCheck'] = False
    data['maxTrialNum'] = data.pop('maxTrialNumber', 99999)
    data['trainingServicePlatform'] = ts['platform']
    ss = data.pop('searchSpace', None)
    ss_file = data.pop('searchSpaceFile', None)
    if ss is not None:
        ss_file = NamedTemporaryFile('w', delete=False)
        json.dump(ss, ss_file, indent=4)
        data['searchSpacePath'] = ss_file.name
    elif ss_file is not None:
        data['searchSpacePath'] = ss_file
    if 'experimentWorkingDirectory' in data:
        data['logDir'] = data.pop('experimentWorkingDirectory')

    for algo_type in ['tuner', 'assessor', 'advisor']:
        algo = data.get(algo_type)
        if algo is None:
            continue
        if algo['name'] is not None:  # builtin
            algo['builtin' + algo_type.title() + 'Name'] = algo.pop('name')
            algo.pop('className', None)
            algo.pop('codeDirectory', None)
        else:
            algo.pop('name', None)
            class_name_parts = algo.pop('className').split('.')
            algo['codeDir'] = algo.pop('codeDirectory', '') + '/'.join(class_name_parts[:-2])
            algo['classFileName'] = class_name_parts[-2] + '.py'
            algo['className'] = class_name_parts[-1]

    tuner_gpu_indices = _convert_gpu_indices(data.pop('tunerGpuIndices', None))
    if tuner_gpu_indices is not None:
        data['tuner']['gpuIndices'] = tuner_gpu_indices

    data['trial'] = {
        'command': data.pop('trialCommand'),
        'codeDir': data.pop('trialCodeDirectory'),
        'gpuNum': data.pop('trialGpuNumber', '')
    }

    if ts['platform'] == 'local':
        data['localConfig'] = {
            'useActiveGpu': ts['useActiveGpu'],
            'maxTrialNumPerGpu': ts['maxTrialNumberPerGpu']
        }
        if ts.get('gpuIndices') is not None:
            data['localConfig']['gpuIndices'] = ','.join(str(idx) for idx in ts['gpuIndices'])

    elif ts['platform'] == 'remote':
        data['remoteConfig'] = {'reuse': ts['reuseMode']}
        data['machineList'] = []
        for machine in ts['machineList']:
            machine = {
                'ip': machine['host'],
                'username': machine['user'],
                'passwd': machine['password'],
                'sshKeyPath': machine['sshKeyFile'],
                'passphrase': machine['sshPassphrase'],
                'gpuIndices': _convert_gpu_indices(machine['gpuIndices']),
                'maxTrialNumPerGpu': machine['maxTrialNumberPerGpu'],
                'useActiveGpu': machine['useActiveGpu'],
                'preCommand': machine['trialPrepareCommand']
            }
            data['machineList'].append(machine)  # collect the converted machine, otherwise machineList stays empty

    elif ts['platform'] == 'pai':
        data['trial']['cpuNum'] = ts['trialCpuNumber']
        data['trial']['memoryMB'] = util.parse_size(ts['trialMemorySize'])
        data['trial']['image'] = ts['dockerImage']  # v2 JSON keys are camelCase
        data['paiConfig'] = {
            'userName': ts['username'],
            'token': ts['token'],
            'host': 'https://' + ts['host'],
            'reuse': ts['reuseMode']
        }

    return data
def _convert_gpu_indices(indices):
    return ','.join(str(idx) for idx in indices) if indices is not None else None
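Its behavior is simple enough to pin down with a couple of checks, which follow directly from the definition:

# Quick sanity checks of the helper above.
assert _convert_gpu_indices([0, 1, 3]) == '0,1,3'
assert _convert_gpu_indices(None) is None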
def to_cluster_metadata(config: ExperimentConfig) -> List[Dict[str, Any]]:
    experiment_config = to_v1_yaml(config, skip_nnictl=True)
    ret = []

    if config.training_service.platform == 'local':
        request_data = dict()
        request_data['local_config'] = experiment_config['localConfig']
        if request_data['local_config']:
            if request_data['local_config'].get('gpuIndices') and isinstance(request_data['local_config'].get('gpuIndices'), int):
                request_data['local_config']['gpuIndices'] = str(request_data['local_config'].get('gpuIndices'))
            if request_data['local_config'].get('maxTrialNumOnEachGpu'):
                request_data['local_config']['maxTrialNumOnEachGpu'] = request_data['local_config'].get('maxTrialNumOnEachGpu')
            if request_data['local_config'].get('useActiveGpu'):
                request_data['local_config']['useActiveGpu'] = request_data['local_config'].get('useActiveGpu')
        ret.append(request_data)

    elif config.training_service.platform == 'remote':
        request_data = dict()
        if experiment_config.get('remoteConfig'):
            request_data['remote_config'] = experiment_config['remoteConfig']
        else:
            request_data['remote_config'] = {'reuse': False}
        request_data['machine_list'] = experiment_config['machineList']
        if request_data['machine_list']:
            for i in range(len(request_data['machine_list'])):
                if isinstance(request_data['machine_list'][i].get('gpuIndices'), int):
                    request_data['machine_list'][i]['gpuIndices'] = str(request_data['machine_list'][i].get('gpuIndices'))
        ret.append(request_data)

    elif config.training_service.platform == 'openpai':
        pai_config_data = dict()
        pai_config_data['pai_config'] = experiment_config['paiConfig']
        ret.append(pai_config_data)

    else:
        raise RuntimeError('Unsupported training service ' + config.training_service.platform)

    if experiment_config.get('nniManagerIp') is not None:
        ret.append({'nni_manager_ip': {'nniManagerIp': experiment_config['nniManagerIp']}})
    ret.append({'trial_config': experiment_config['trial']})
    return ret
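Each element of the returned list is a single-key dict that turns into one cluster-metadata REST call; `launcher.py` later in this diff consumes it directly. The payload shapes below are illustrative, not taken from a real run:

# How launcher.py (below) consumes the result:
#     for cluster_metadata in convert.to_cluster_metadata(config):
#         rest.put(port, '/experiment/cluster-metadata', cluster_metadata)
#
# Illustrative payloads for a local training service (example values):
#     {'local_config': {'useActiveGpu': False, 'maxTrialNumPerGpu': 1}}
#     {'trial_config': {'command': 'python3 trial.py', 'codeDir': '/abs/path', 'gpuNum': 0}}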
def to_rest_json(config: ExperimentConfig) -> Dict[str, Any]:
    experiment_config = to_v1_yaml(config, skip_nnictl=True)
    request_data = dict()
    request_data['authorName'] = experiment_config['authorName']
    request_data['experimentName'] = experiment_config['experimentName']
    request_data['trialConcurrency'] = experiment_config['trialConcurrency']
    request_data['maxExecDuration'] = util.parse_time(experiment_config['maxExecDuration'])
    request_data['maxTrialNum'] = experiment_config['maxTrialNum']
    if config.search_space is not None:
        request_data['searchSpace'] = json.dumps(config.search_space)
    else:
        request_data['searchSpace'] = Path(config.search_space_file).read_text()
    request_data['trainingServicePlatform'] = experiment_config.get('trainingServicePlatform')

    if experiment_config.get('advisor'):
        request_data['advisor'] = experiment_config['advisor']
        if request_data['advisor'].get('gpuNum'):
            _logger.warning('gpuNum is deprecated, please use gpuIndices instead.')
        if request_data['advisor'].get('gpuIndices') and isinstance(request_data['advisor'].get('gpuIndices'), int):
            request_data['advisor']['gpuIndices'] = str(request_data['advisor'].get('gpuIndices'))
    elif experiment_config.get('tuner'):
        request_data['tuner'] = experiment_config['tuner']
        if request_data['tuner'].get('gpuNum'):
            _logger.warning('gpuNum is deprecated, please use gpuIndices instead.')
        if request_data['tuner'].get('gpuIndices') and isinstance(request_data['tuner'].get('gpuIndices'), int):
            request_data['tuner']['gpuIndices'] = str(request_data['tuner'].get('gpuIndices'))
        if 'assessor' in experiment_config:
            request_data['assessor'] = experiment_config['assessor']
            if request_data['assessor'].get('gpuNum'):
                _logger.warning('gpuNum is deprecated, please remove it from your config file.')
    else:
        request_data['tuner'] = {'builtinTunerName': '_user_created_'}

    # debug mode should disable version check
    if experiment_config.get('debug') is not None:
        request_data['versionCheck'] = not experiment_config.get('debug')
    # validate version check
    if experiment_config.get('versionCheck') is not None:
        request_data['versionCheck'] = experiment_config.get('versionCheck')
    if experiment_config.get('logCollection'):
        request_data['logCollection'] = experiment_config.get('logCollection')

    request_data['clusterMetaData'] = []
    if experiment_config['trainingServicePlatform'] == 'local':
        request_data['clusterMetaData'].append({'key': 'codeDir', 'value': experiment_config['trial']['codeDir']})
        request_data['clusterMetaData'].append({'key': 'command', 'value': experiment_config['trial']['command']})
    elif experiment_config['trainingServicePlatform'] == 'remote':
        request_data['clusterMetaData'].append({'key': 'machine_list', 'value': experiment_config['machineList']})
        request_data['clusterMetaData'].append({'key': 'trial_config', 'value': experiment_config['trial']})
        if not experiment_config.get('remoteConfig'):
            # set default value of reuse in remoteConfig to False
            experiment_config['remoteConfig'] = {'reuse': False}
        request_data['clusterMetaData'].append({'key': 'remote_config', 'value': experiment_config['remoteConfig']})
    elif experiment_config['trainingServicePlatform'] == 'pai':
        request_data['clusterMetaData'].append({'key': 'pai_config', 'value': experiment_config['paiConfig']})
        request_data['clusterMetaData'].append({'key': 'trial_config', 'value': experiment_config['trial']})
    elif experiment_config['trainingServicePlatform'] == 'kubeflow':
        request_data['clusterMetaData'].append({'key': 'kubeflow_config', 'value': experiment_config['kubeflowConfig']})
        request_data['clusterMetaData'].append({'key': 'trial_config', 'value': experiment_config['trial']})
    elif experiment_config['trainingServicePlatform'] == 'frameworkcontroller':
        request_data['clusterMetaData'].append({'key': 'frameworkcontroller_config', 'value': experiment_config['frameworkcontrollerConfig']})
        request_data['clusterMetaData'].append({'key': 'trial_config', 'value': experiment_config['trial']})
    elif experiment_config['trainingServicePlatform'] == 'aml':
        request_data['clusterMetaData'].append({'key': 'aml_config', 'value': experiment_config['amlConfig']})
        request_data['clusterMetaData'].append({'key': 'trial_config', 'value': experiment_config['trial']})
    return request_data
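For reference, the body posted to `/experiment` for a minimal local setup has roughly the following shape (illustrative values; `maxExecDuration` has already been converted to seconds by `util.parse_time`, and `searchSpace` is always a serialized JSON string):

# Illustrative POST /experiment body for a minimal local setup (example values):
# {
#     'authorName': 'N/A',
#     'experimentName': 'N/A',
#     'trialConcurrency': 1,
#     'maxExecDuration': 86313600,        # the '999d' default, in seconds
#     'maxTrialNum': 99999,
#     'searchSpace': '{"lr": {...}}',     # a JSON string, never a dict
#     'trainingServicePlatform': 'local',
#     'tuner': {'builtinTunerName': '_user_created_'},
#     'versionCheck': True,
#     'clusterMetaData': [
#         {'key': 'codeDir', 'value': '/abs/path'},
#         {'key': 'command', 'value': 'python3 trial.py'},
#     ],
# }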
nni/experiment/config/local.py
View file @ 3ec26b40
...
@@ -2,39 +2,25 @@
 # Licensed under the MIT license.

 from dataclasses import dataclass
-from pathlib import Path
-from typing import Any, Dict
+from typing import List, Optional, Union

-from .base import ExperimentConfig
+from .common import TrainingServiceConfig

+__all__ = ['LocalConfig']


 @dataclass(init=False)
-class LocalExperimentConfig(ExperimentConfig):
-    use_active_gpu: bool = False
+class LocalConfig(TrainingServiceConfig):
+    platform: str = 'local'
+    use_active_gpu: bool
+    max_trial_number_per_gpu: int = 1
+    gpu_indices: Optional[Union[List[int], str]] = None

-    _training_service: str = 'local'
+    _canonical_rules = {
+        'gpu_indices': lambda value: [int(idx) for idx in value.split(',')] if isinstance(value, str) else value
+    }

-    def experiment_config_json(self) -> Dict[str, Any]:
-        ret = super().experiment_config_json()
-        ret['clusterMetaData'] = [
-            {
-                'key': 'codeDir',
-                'value': str(Path(self.trial_code_directory).resolve())
-            },
-            {
-                'key': 'command',
-                'value': self.trial_command
-            }
-        ]
-        #ret['local_config'] = {
-        #    'useActiveGpu': self.use_active_gpu
-        #}
-        return ret

-    def cluster_metadata_json(self) -> Any:
-        return {
-            'trial_config': {
-                'command': self.trial_command,
-                'codeDir': str(Path(self.trial_code_directory).resolve())
-            }
-        }
+    _validation_rules = {
+        'platform': lambda value: (value == 'local', 'cannot be modified'),
+        'max_trial_number_per_gpu': lambda value: value > 0,
+        'gpu_indices': lambda value: all(idx >= 0 for idx in value) and len(value) == len(set(value))
+    }
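The canonical rule on `gpu_indices` makes the string and list spellings equivalent; a quick check against the class attribute defined above:

# Illustration of LocalConfig's canonical rule (definitions as above).
rule = LocalConfig._canonical_rules['gpu_indices']
assert rule('0,1,2') == [0, 1, 2]   # the comma-separated string form is canonicalized to a list
assert rule([0, 1]) == [0, 1]       # the list form passes through unchanged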
nni/experiment/config/util.py
0 → 100644
View file @ 3ec26b40
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""
Miscellaneous utility functions.
"""

import math
import os.path
from pathlib import Path
from typing import Optional, Union

PathLike = Union[Path, str]


def case_insensitive(key: str) -> str:
    return key.lower().replace('_', '')

def camel_case(key: str) -> str:
    words = key.split('_')
    return words[0] + ''.join(word.title() for word in words[1:])


def canonical_path(path: Optional[PathLike]) -> Optional[str]:
    # Path.resolve() does not work on Windows when file not exist, so use os.path instead
    return os.path.abspath(os.path.expanduser(path)) if path is not None else None


def count(*values) -> int:
    return sum(value is not None and value is not False for value in values)


def training_service_config_factory(platform: str):  # -> TrainingServiceConfig
    from .common import TrainingServiceConfig
    for cls in TrainingServiceConfig.__subclasses__():
        if cls.platform == platform:
            return cls()
    raise ValueError(f'Unrecognized platform {platform}')


def strip_optional(type_hint):
    return type_hint.__args__[0] if str(type_hint).startswith('typing.Optional[') else type_hint


def parse_time(time: str, target_unit: str = 's') -> int:
    return _parse_unit(time.lower(), target_unit, _time_units)

def parse_size(size: str, target_unit: str = 'mb') -> int:
    return _parse_unit(size.lower(), target_unit, _size_units)

_time_units = {'d': 24 * 3600, 'h': 3600, 'm': 60, 's': 1}
_size_units = {'gb': 1024 * 1024 * 1024, 'mb': 1024 * 1024, 'kb': 1024}

def _parse_unit(string, target_unit, all_units):
    for unit, factor in all_units.items():
        if string.endswith(unit):
            number = string[:-len(unit)]
            value = float(number) * factor
            return math.ceil(value / all_units[target_unit])
    raise ValueError(f'Unsupported unit in "{string}"')
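Since the unit helpers round with `math.ceil` and the key converters are purely mechanical, their behavior pins down to concrete values that follow directly from the definitions above:

# Direct consequences of the definitions above.
assert parse_time('1h') == 3600           # hours to the default target 's'
assert parse_time('90m', 'h') == 2        # 5400 s / 3600, rounded up by math.ceil
assert parse_size('1gb') == 1024          # gigabytes to the default target 'mb'
assert camel_case('max_trial_number') == 'maxTrialNumber'
assert case_insensitive('max_trial_number') == 'maxtrialnumber'
assert count('x', None, False, 0) == 2    # None and False don't count; 0 does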
nni/experiment/experiment.py
View file @ 3ec26b40
 import atexit
 import logging
 import socket
 from subprocess import Popen
 import time
 from threading import Thread
-from typing import Optional, overload, List, Union, Callable
+from typing import Optional, overload

 import colorama
 import psutil

+import nni.runtime.log
 from nni.runtime.msg_dispatcher import MsgDispatcher
 from nni.tuner import Tuner
-from nni.retiarii.integration import RetiariiAdvisor
-from nni.retiarii.converter.graph_gen import convert_to_graph

 from .config import ExperimentConfig
 from . import launcher
 from .pipe import Pipe
 from . import rest

-_logger = logging.getLogger(__name__)
+nni.runtime.log.init_logger_experiment()
+_logger = logging.getLogger('nni.experiment')


 class Experiment:
     """
-    Controls an NNI experiment.
-    You may either create a new NNI experiment with constructor and `Experiment.start()`,
-    # TODO: or control an existing experiment with `Experiment.connect()`.
+    Create and stop an NNI experiment.

     Attributes
     ----------
...
@@ -42,7 +44,7 @@ class Experiment:
         Parameters
         ----------
         tuner
-            A tuner instance.
-        # TODO: accessor / advisor
+            A tuner instance.
         config
             Experiment configuration.
         """
...
@@ -67,24 +69,24 @@ class Experiment:
             A tuner instance.
         training_service
             Name of training service.
-            Supported value: "local", "remote", "openpai"/"pai".
+            Supported value: "local", "remote", "openpai".
         """
         ...

     def __init__(self, tuner: Tuner, config=None, training_service=None):
         self.config: ExperimentConfig
         self.port: Optional[int] = None
-        self._dispatcher = MsgDispatcher(tuner, None)
+        self.tuner: Tuner = tuner
         self._proc: Optional[Popen] = None
         self._pipe: Optional[Pipe] = None
+        self._dispatcher: Optional[MsgDispatcher] = None
+        self._dispatcher_thread: Optional[Thread] = None

         if isinstance(config, str):
             config, training_service = None, config
-        if training_service == 'openpai':
-            training_service = 'pai'

         if config is None:
-            self.config = ExperimentConfig.create_template(training_service)
+            self.config = ExperimentConfig(training_service)
         else:
             self.config = config
...
@@ -103,6 +105,8 @@ class Experiment:
         debug
             Whether to start in debug mode.
         """
+        atexit.register(self.stop)
+
         if debug:
             logging.getLogger('nni').setLevel(logging.DEBUG)
...
@@ -112,9 +116,20 @@ class Experiment:
         self.port = port  # port will be None if start up failed

-        # dispatcher must be created after pipe initialized
+        # dispatcher must be launched after pipe initialized
         # the logic to launch dispatcher in background should be refactored into dispatcher api
-        Thread(target=self._dispatcher.run).start()
+        self._dispatcher = MsgDispatcher(self.tuner, None)
+        self._dispatcher_thread = Thread(target=self._dispatcher.run)
+        self._dispatcher_thread.start()
+
+        ips = [self.config.nni_manager_ip]
+        for interfaces in psutil.net_if_addrs().values():
+            for interface in interfaces:
+                if interface.family == socket.AF_INET:
+                    ips.append(interface.address)
+        ips = [f'http://{ip}:{port}' for ip in ips if ip]
+        msg = 'Web UI URLs: ' + colorama.Fore.CYAN + ' '.join(ips)
+        _logger.info(msg)

         # TODO: register experiment management metadata
...
@@ -123,27 +138,41 @@ class Experiment:
         """
         Stop background experiment.
         """
-        self._proc.kill()
-        self._pipe.close()
+        _logger.info('Stopping experiment...')
+        atexit.unregister(self.stop)
+
+        if self._proc is not None:
+            self._proc.kill()
+        if self._pipe is not None:
+            self._pipe.close()
+        if self._dispatcher_thread is not None:
+            self._dispatcher.stopping = True
+            self._dispatcher_thread.join(timeout=1)
+
         self.port = None
         self._proc = None
         self._pipe = None
+        self._dispatcher = None
+        self._dispatcher_thread = None

-    def run(self, port: int = 8080, debug: bool = False) -> str:
+    def run(self, port: int = 8080, debug: bool = False) -> bool:
         """
         Run the experiment.
         This function will block until experiment finish or error.
+        Return `True` when experiment done; or return `False` when experiment failed.
         """
         self.start(port, debug)
         try:
             while True:
                 time.sleep(10)
                 status = self.get_status()
-                if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
-                    return status
+                if status == 'STOPPED':
+                    return True
+                if status == 'ERROR':
+                    return False
         finally:
             self.stop()
...
@@ -153,97 +182,3 @@ class Experiment:
             raise RuntimeError('Experiment is not running')
         resp = rest.get(self.port, '/check-status')
         return resp['status']
-
-
-class RetiariiExperiment(Experiment):
-    def __init__(self, base_model: 'nn.Module', trainer: 'BaseTrainer',
-                 applied_mutators: List['Mutator'], strategy: 'BaseStrategy',
-                 tca: 'TraceClassArguments' = None):
-        self.config: ExperimentConfig = None
-        self.port: Optional[int] = None
-        self.base_model = base_model
-        self.trainer = trainer
-        self.applied_mutators = applied_mutators
-        self.strategy = strategy
-        self.recorded_module_args = tca.recorded_arguments  # FIXME: remove this argument
-        self._dispatcher = RetiariiAdvisor()
-        self._proc: Optional[Popen] = None
-        self._pipe: Optional[Pipe] = None
-
-    def _start_strategy(self):
-        import torch
-        script_module = torch.jit.script(self.base_model)
-        base_model = convert_to_graph(script_module, self.base_model, self.recorded_module_args)
-        assert id(self.trainer) in self.recorded_module_args
-        trainer_config = self.recorded_module_args[id(self.trainer)]
-        _logger.info('Starting strategy...')
-        Thread(target=self.strategy.run, args=(base_model, self.applied_mutators, trainer_config)).start()
-        _logger.info('Strategy started!')
-
-    def start(self, config: ExperimentConfig, port: int = 8080, debug: bool = False) -> None:
-        """
-        Start the experiment in background.
-        This method will raise exception on failure.
-        If it returns, the experiment should have been successfully started.
-
-        Parameters
-        ----------
-        port
-            The port of web UI.
-        debug
-            Whether to start in debug mode.
-        """
-        if debug:
-            logging.getLogger('nni').setLevel(logging.DEBUG)
-
-        self._proc, self._pipe = launcher.start_experiment(config, port, debug)
-        assert self._proc is not None
-        assert self._pipe is not None
-        self.port = port  # port will be None if start up failed
-
-        # dispatcher must be created after pipe initialized
-        # the logic to launch dispatcher in background should be refactored into dispatcher api
-        Thread(target=self._dispatcher.run).start()
-        self._start_strategy()
-        # TODO: register experiment management metadata
-
-    def stop(self) -> None:
-        """
-        Stop background experiment.
-        """
-        self._proc.kill()
-        self._pipe.close()
-        self.port = None
-        self._proc = None
-        self._pipe = None
-
-    def run(self, config: ExperimentConfig, port: int = 8080, debug: bool = False) -> str:
-        """
-        Run the experiment.
-        This function will block until experiment finish or error.
-        """
-        self.start(config, port, debug)
-        try:
-            while True:
-                time.sleep(10)
-                status = self.get_status()
-                if status in ['ERROR', 'STOPPED', 'NO_MORE_TRIAL']:
-                    return status
-        finally:
-            self.stop()
-
-    def get_status(self) -> str:
-        if self.port is None:
-            raise RuntimeError('Experiment is not running')
-        resp = rest.get(self.port, '/check-status')
-        return resp['status']
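Taken together, the reworked `Experiment` class supports a compact launch flow. A minimal usage sketch — the tuner import path and `trial.py` are illustrative assumptions rather than part of this commit, and config field names follow `common.py` above; adjust to your NNI version:

# Hedged usage sketch of the new API; import paths and trial.py are examples.
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
from nni.experiment import Experiment

tuner = HyperoptTuner('tpe')
experiment = Experiment(tuner, 'local')                 # training service chosen by name
experiment.config.trial_command = 'python3 trial.py'    # your trial script
experiment.config.trial_code_directory = '.'
experiment.config.trial_concurrency = 2
experiment.config.search_space = {
    'lr': {'_type': 'loguniform', '_value': [1e-4, 1e-1]}
}

success = experiment.run(port=8080)   # blocks; True on STOPPED, False on ERROR
print('done' if success else 'failed')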
nni/experiment/launcher.py
View file @ 3ec26b40
 import contextlib
 import logging
 from pathlib import Path
 import socket
 from subprocess import Popen
...
@@ -6,40 +7,46 @@ import sys
 import time
 from typing import Optional, Tuple

 import colorama

 import nni.runtime.protocol
 import nni_node

 from .config import ExperimentConfig
+from .config import convert
 from . import management
 from .pipe import Pipe
 from . import rest

 _logger = logging.getLogger('nni.experiment')


 def start_experiment(config: ExperimentConfig, port: int, debug: bool) -> Tuple[Popen, Pipe]:
     pipe = None
     proc = None

-    config.validate()
+    config.validate(initialized_tuner=True)
     _ensure_port_idle(port)
-    if config._training_service == 'pai':
+    if config.training_service.platform == 'openpai':
         _ensure_port_idle(port + 1, 'OpenPAI requires an additional port')

     exp_id = management.generate_experiment_id()

     try:
-        print(f'Creating experiment {exp_id}...')
+        _logger.info(f'Creating experiment {colorama.Fore.CYAN}{exp_id}')
         pipe = Pipe(exp_id)
         proc = _start_rest_server(config, port, debug, exp_id, pipe.path)
         _logger.info('Connecting IPC pipe...')
         pipe_file = pipe.connect()
         nni.runtime.protocol._in_file = pipe_file
         nni.runtime.protocol._out_file = pipe_file
-        print('Starting web server...')
+        _logger.info('Starting web server...')
         _check_rest_server(port)
-        print('Setting up...')
-        _init_experiment(config, port, debug)  # todo: kill on fail
+        _logger.info('Setting up...')
+        _init_experiment(config, port, debug)
         return proc, pipe

     except Exception as e:
-        print('Create experiment failed')
+        _logger.error('Create experiment failed')
         if proc is not None:
             with contextlib.suppress(Exception):
                 proc.kill()
...
@@ -58,9 +65,13 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:


 def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experiment_id: str, pipe_path: str) -> Popen:
+    ts = config.training_service.platform
+    if ts == 'openpai':
+        ts = 'pai'
+
     args = {
         'port': port,
-        'mode': config._training_service,
+        'mode': ts,
         'experiment_id': experiment_id,
         'start_mode': 'new',
         'log_level': 'debug' if debug else 'info',
...
@@ -77,15 +88,18 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim
     return Popen(cmd, cwd=node_dir)


-def _check_rest_server(port: int, retry: int = 10) -> None:
-    for _ in range(retry):
+def _check_rest_server(port: int, retry: int = 3) -> None:
+    for i in range(retry):
         with contextlib.suppress(Exception):
             rest.get(port, '/check-status')
             return
+        if i > 0:
+            _logger.warning('Timeout, retry...')
         time.sleep(1)
     rest.get(port, '/check-status')


 def _init_experiment(config: ExperimentConfig, port: int, debug: bool) -> None:
-    rest.put(port, '/experiment/cluster-metadata', config.cluster_metadata_json())
-    rest.post(port, '/experiment', config.experiment_config_json())
+    for cluster_metadata in convert.to_cluster_metadata(config):
+        rest.put(port, '/experiment/cluster-metadata', cluster_metadata)
+    rest.post(port, '/experiment', convert.to_rest_json(config))
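The ordering inside `start_experiment()` is deliberate: the pipe must exist before the REST server is launched, and the manager must answer status checks before any metadata is pushed. Condensed, the handshake is:

# Condensed bootstrap sequence implemented by start_experiment() above:
#   1. pipe = Pipe(exp_id)              create the IPC pipe for the tuner
#   2. proc = _start_rest_server(...)   launch the NNI manager process
#   3. pipe.connect()                   wire nni.runtime.protocol to the pipe
#   4. _check_rest_server(port)         poll GET /check-status until the manager answers
#   5. _init_experiment(...)            PUT cluster metadata, then POST /experiment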
nni/experiment/nni_client.py
View file @ 3ec26b40
...
@@ -26,7 +26,6 @@ import subprocess
 import re
 import json
 import requests
-import yaml

 __all__ = [
     'ExternalExperiment',
...
@@ -260,38 +259,6 @@ class ExternalExperiment:
         self._endpoint = 'http://localhost:{}'.format(self._port)
         self._exp_id = self.get_experiment_profile()['id']

-    def tmp_start_retiarii(self, graph_ir, training_approach, applied_mutators, strategy, exp_config):
-        # prepare search space file which includes base graph IR and mutators
-        search_space = {}
-        search_space['base_model_ir'] = graph_ir
-        search_space['applied_mutators'] = applied_mutators
-        search_space['training_approach'] = training_approach
-        with open('search_space.json', 'w') as f:
-            json.dump(search_space, f)
-        # add advisor config to exp_config
-        exp_config['searchSpacePath'] = 'search_space.json'
-        exp_config['useAnnotation'] = False
-        exp_config['advisor'] = {
-            'codeDir': '.',
-            'classFileName': 'advisor_entry.py',
-            'className': 'RetiariiAdvisor',
-            'classArgs': {
-                'strategy': '{}.{}'.format(strategy['filename'], strategy['funcname'])
-            }
-        }
-        # add trial config to exp_config
-        exp_config['trial'] = {
-            'command': 'python3 -m nni.retiarii.trial_entry',
-            'codeDir': '../..',
-            'gpuNum': 0
-        }
-        # dump exp_config to nni.yml
-        with open('nni.yml', 'w') as f:
-            yaml.dump(exp_config, f)
-        # start experiment
-        self.start_experiment('nni.yml')

     def start_experiment(self, config_file, port=None, debug=False):
         """
         Start an experiment with specified configuration file and connect to it.
...
nni/experiment/pipe.py
View file @ 3ec26b40
...
@@ -3,7 +3,7 @@ import os
 import sys

 if sys.platform == 'win32':
-    import _win32
+    import _winapi
     import msvcrt

     class WindowsPipe:
...
@@ -11,27 +11,27 @@ if sys.platform == 'win32':
             self.path: str = r'\\.\pipe\nni-' + experiment_id
             self.file = None

-            self._handle = _win32.CreateNamedPipe(
+            self._handle = _winapi.CreateNamedPipe(
                 self.path,
-                _win32.PIPE_ACCESS_DUPLEX,
-                _win32.PIPE_TYPE_MESSAGE | _win32.PIPE_READMODE_MESSAGE | _win32.PIPE_WAIT,
+                _winapi.PIPE_ACCESS_DUPLEX,
+                _winapi.PIPE_TYPE_MESSAGE | _winapi.PIPE_READMODE_MESSAGE | _winapi.PIPE_WAIT,
                 1,
                 8192,
                 8192,
                 0,
-                _win32.NULL
+                _winapi.NULL
             )

         def connect(self) -> BufferedIOBase:
-            _win32.ConnectNamedPipe(self._handle, _win32.NULL)
-            fd = msvcrt.open_osfhandle(self._handle)
-            self.file = os.fdopen(fd, 'rwb')
+            _winapi.ConnectNamedPipe(self._handle, _winapi.NULL)
+            fd = msvcrt.open_osfhandle(self._handle, 0)
+            self.file = os.fdopen(fd, 'w+b')
             return self.file

         def close(self) -> None:
             if self.file is not None:
                 self.file.close()
-            _win32.CloseHandle(self._handle)
+            _winapi.CloseHandle(self._handle)

     Pipe = WindowsPipe
...
@@ -52,7 +52,7 @@ else:
         def connect(self) -> BufferedIOBase:
             conn, _ = self._socket.accept()
-            self.file = conn.makefile('rwb')
+            self.file = conn.makefile('w+b')
             return self.file

         def close(self) -> None:
...
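The mode-string change in both pipe implementations is a genuine bug fix: `'rwb'` is rejected by Python's io layer, while `'w+b'` opens the stream in binary for both reading and writing. A standalone demonstration (not NNI code):

# Why the mode string changed: 'rwb' is invalid; 'w+b' is binary read/write.
import os
import tempfile

path = os.path.join(tempfile.mkdtemp(), 'demo.bin')
try:
    open(path, 'rwb')
except ValueError as err:
    print(err)   # complains about mixing read and write in one mode letter
f = open(path, 'w+b')
f.write(b'ping')
f.seek(0)
assert f.read() == b'ping'
f.close()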
nni/runtime/log.py
View file @ 3ec26b40
...
@@ -4,8 +4,11 @@ import logging
 from logging import FileHandler, Formatter, Handler, StreamHandler
 from pathlib import Path
 import sys
 import time
 from typing import Optional

+import colorama
+
 from .env_vars import dispatcher_env_vars, trial_env_vars
...
@@ -17,6 +20,8 @@ def init_logger() -> None:
     The detection should work in most cases but for `nnictl` and `nni.experiment`.
     They will be identified as "standalone" mode and must configure the logger by themselves.
     """
+    colorama.init()
+
     if dispatcher_env_vars.SDK_PROCESS == 'dispatcher':
         _init_logger_dispatcher()
         return
...
@@ -33,6 +38,15 @@ def init_logger() -> None:
     _init_logger_standalone()


+def init_logger_experiment() -> None:
+    """
+    Initialize logger for `nni.experiment.Experiment`.
+    This function will get invoked after `init_logger()`.
+    """
+    formatter.format = _colorful_format


 time_format = '%Y-%m-%d %H:%M:%S'

 formatter = Formatter(
...
@@ -40,14 +54,14 @@ formatter = Formatter(
     time_format
 )


 def _init_logger_dispatcher() -> None:
     log_level_map = {
         'fatal': logging.CRITICAL,
         'error': logging.ERROR,
         'warning': logging.WARNING,
         'info': logging.INFO,
-        'debug': logging.DEBUG
+        'debug': logging.DEBUG,
+        'trace': 0
     }

     log_path = _prepare_log_dir(dispatcher_env_vars.NNI_LOG_DIRECTORY) / 'dispatcher.log'
...
@@ -93,6 +107,21 @@ def _setup_logger(name: str, handler: Handler, level: int) -> None:
     logger.setLevel(level)
     logger.propagate = False


+def _colorful_format(record):
+    if record.levelno >= logging.ERROR:
+        color = colorama.Fore.RED
+    elif record.levelno >= logging.WARNING:
+        color = colorama.Fore.YELLOW
+    elif record.levelno >= logging.INFO:
+        color = colorama.Fore.GREEN
+    else:
+        color = colorama.Fore.BLUE
+    msg = color + (record.msg % record.args) + colorama.Style.RESET_ALL
+    time = formatter.formatTime(record, time_format)
+    if record.levelno < logging.INFO:
+        return '[{}] {}:{} {}'.format(time, record.threadName, record.name, msg)
+    else:
+        return '[{}] {}'.format(time, msg)


 class _LogFileWrapper(TextIOBase):
     # wrap the logger file so that anything written to it will automatically get formatted
...
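One subtlety in `init_logger_experiment()` above: `formatter.format = _colorful_format` is an instance-level override, so the plain function shadows the bound method and receives only the record (no `self`), and only this one Formatter instance is affected. The same trick in isolation (not NNI code):

# Standalone sketch of the instance-level method override used above.
import logging

fmt = logging.Formatter('%(message)s')
fmt.format = lambda record: '>> ' + record.getMessage()   # shadows the bound method

record = logging.LogRecord('demo', logging.INFO, __file__, 1, 'hello %s', ('world',), None)
assert fmt.format(record) == '>> hello world'             # other Formatter instances are untouched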
nni/runtime/msg_dispatcher_base.py
View file @ 3ec26b40
...
@@ -25,11 +25,11 @@ class MsgDispatcherBase(Recoverable):
     """

     def __init__(self):
+        self.stopping = False
         if multi_thread_enabled():
             self.pool = ThreadPool()
             self.thread_results = []
         else:
-            self.stopping = False
             self.default_command_queue = Queue()
             self.assessor_command_queue = Queue()
             self.default_worker = threading.Thread(target=self.command_queue_worker, args=(self.default_command_queue,))
...
@@ -43,11 +43,11 @@ class MsgDispatcherBase(Recoverable):
         """Run the tuner.
         This function will never return unless an exception is raised.
         """
-        _logger.info('Start dispatcher')
+        _logger.info('Dispatcher started')
         if dispatcher_env_vars.NNI_MODE == 'resume':
             self.load_checkpoint()

-        while True:
+        while not self.stopping:
             command, data = receive()
             if data:
                 data = json_tricks.loads(data)
...
@@ -75,7 +75,7 @@ class MsgDispatcherBase(Recoverable):
         self.default_worker.join()
         self.assessor_worker.join()
-        _logger.info('Terminated by NNI manager')
+        _logger.info('Dispatcher terminated')

     def command_queue_worker(self, command_queue):
         """Process commands in command queues.
...
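Hoisting `stopping` out of the single-thread branch is what lets `Experiment.stop()` (earlier in this diff) shut the dispatcher thread down cooperatively: set the flag, then `join` with a timeout. The same pattern in isolation — not NNI code, and note that a loop blocked inside `receive()` still needs input to arrive before it can observe the flag:

# Minimal sketch of the cooperative-stop pattern used above.
import threading
import time

class Worker:
    def __init__(self):
        self.stopping = False

    def run(self):
        while not self.stopping:   # flag checked once per loop iteration
            time.sleep(0.1)        # stand-in for receive()/command handling

worker = Worker()
thread = threading.Thread(target=worker.run)
thread.start()
worker.stopping = True             # ask the loop to exit...
thread.join(timeout=1)             # ...and wait briefly, as Experiment.stop() does
assert not thread.is_alive()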