Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
4784cc6c
Unverified
Commit
4784cc6c
authored
Jan 14, 2021
by
liuzhe-lz
Committed by
GitHub
Jan 14, 2021
Browse files
Merge pull request #3302 from microsoft/v2.0-merge
Merge branch v2.0 into master (no squash)
parents
25db55ca
349ead41
Changes
291
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
242 additions
and
243 deletions
+242
-243
examples/trials/network_morphism/cifar10/config_paiYarn.yml
examples/trials/network_morphism/cifar10/config_paiYarn.yml
+0
-39
examples/trials/network_morphism/requirements.txt
examples/trials/network_morphism/requirements.txt
+1
-1
examples/trials/sklearn/classification/config_paiYarn.yml
examples/trials/sklearn/classification/config_paiYarn.yml
+0
-32
examples/trials/sklearn/regression/config_paiYarn.yml
examples/trials/sklearn/regression/config_paiYarn.yml
+0
-32
nni/__init__.py
nni/__init__.py
+4
-1
nni/algorithms/compression/pytorch/quantization/quantizers.py
...algorithms/compression/pytorch/quantization/quantizers.py
+25
-26
nni/algorithms/nas/pytorch/cdarts/mutator.py
nni/algorithms/nas/pytorch/cdarts/mutator.py
+1
-1
nni/algorithms/nas/pytorch/spos/trainer.py
nni/algorithms/nas/pytorch/spos/trainer.py
+1
-1
nni/compression/pytorch/compressor.py
nni/compression/pytorch/compressor.py
+20
-12
nni/compression/pytorch/speedup/infer_shape.py
nni/compression/pytorch/speedup/infer_shape.py
+16
-18
nni/compression/pytorch/utils/mask_conflict.py
nni/compression/pytorch/utils/mask_conflict.py
+8
-1
nni/experiment/config/common.py
nni/experiment/config/common.py
+9
-5
nni/experiment/config/convert.py
nni/experiment/config/convert.py
+63
-37
nni/experiment/config/openpai.py
nni/experiment/config/openpai.py
+5
-2
nni/experiment/config/remote.py
nni/experiment/config/remote.py
+8
-1
nni/experiment/config/util.py
nni/experiment/config/util.py
+20
-6
nni/experiment/experiment.py
nni/experiment/experiment.py
+29
-13
nni/experiment/launcher.py
nni/experiment/launcher.py
+30
-12
nni/experiment/pipe.py
nni/experiment/pipe.py
+0
-1
nni/nas/pytorch/base_mutator.py
nni/nas/pytorch/base_mutator.py
+2
-2
No files found.
examples/trials/network_morphism/cifar10/config_paiYarn.yml
deleted
100644 → 0
View file @
25db55ca
authorName
:
default
experimentName
:
example_cifar10-network-morphism
trialConcurrency
:
1
maxExecDuration
:
24h
maxTrialNum
:
10
#choice: local, remote, pai
trainingServicePlatform
:
paiYarn
#choice: true, false
useAnnotation
:
false
tuner
:
#choice: TPE, Random, Anneal, Evolution, BatchTuner, NetworkMorphism
#SMAC (SMAC should be installed through nnictl)
builtinTunerName
:
NetworkMorphism
classArgs
:
#choice: maximize, minimize
optimize_mode
:
maximize
# for now, this tuner only supports cv domain
task
:
cv
#input image width
input_width
:
32
#input image channel
input_channel
:
3
#number of classes
n_output_node
:
10
trial
:
command
:
python3 cifar10_keras.py
codeDir
:
.
gpuNum
:
1
cpuNum
:
1
memoryMB
:
8196
#The docker image to run nni job on pai
image
:
msranni/nni:latest
paiYarnConfig
:
#The username to login pai
userName
:
username
#The password to login pai
passWord
:
password
#The host of restful server of pai
host
:
10.10.10.10
\ No newline at end of file
examples/trials/network_morphism/requirements.txt
View file @
4784cc6c
numpy==1.14.2
numpy==1.19.3
tensorflow==1.15.4
tensorflow==1.15.4
torchvision==0.2.1
torchvision==0.2.1
Keras==2.3.1
Keras==2.3.1
...
...
examples/trials/sklearn/classification/config_paiYarn.yml
deleted
100644 → 0
View file @
25db55ca
authorName
:
default
experimentName
:
example_sklearn
trialConcurrency
:
1
maxExecDuration
:
1h
maxTrialNum
:
100
#choice: local, remote, pai
trainingServicePlatform
:
paiYarn
searchSpacePath
:
search_space.json
#choice: true, false
useAnnotation
:
false
tuner
:
#choice: TPE, Random, Anneal, Evolution, BatchTuner,MetisTuner
#SMAC (SMAC should be installed through nnictl)
builtinTunerName
:
TPE
classArgs
:
#choice: maximize, minimize
optimize_mode
:
maximize
trial
:
command
:
python3 main.py
codeDir
:
.
gpuNum
:
0
cpuNum
:
1
memoryMB
:
8196
#The docker image to run nni job on pai
image
:
msranni/nni:latest
paiYarnConfig
:
#The username to login pai
userName
:
username
#The password to login pai
passWord
:
password
#The host of restful server of pai
host
:
10.10.10.10
\ No newline at end of file
examples/trials/sklearn/regression/config_paiYarn.yml
deleted
100644 → 0
View file @
25db55ca
authorName
:
default
experimentName
:
example_sklearn
trialConcurrency
:
1
maxExecDuration
:
1h
maxTrialNum
:
100
#choice: local, remote, pai
trainingServicePlatform
:
paiYarn
searchSpacePath
:
search_space.json
#choice: true, false
useAnnotation
:
false
tuner
:
#choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner
#SMAC (SMAC should be installed through nnictl)
builtinTunerName
:
TPE
classArgs
:
#choice: maximize, minimize
optimize_mode
:
maximize
trial
:
command
:
python3 main.py
codeDir
:
.
gpuNum
:
0
cpuNum
:
1
memoryMB
:
8196
#The docker image to run nni job on pai
image
:
msranni/nni:latest
paiYarnConfig
:
#The username to login pai
userName
:
username
#The password to login pai
passWord
:
password
#The host of restful server of pai
host
:
10.10.10.10
\ No newline at end of file
nni/__init__.py
View file @
4784cc6c
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
__version__
=
'999.0.0-developing'
try
:
from
.version
import
__version__
except
ModuleNotFoundError
:
__version__
=
'999.dev0'
from
.runtime.log
import
init_logger
from
.runtime.log
import
init_logger
init_logger
()
init_logger
()
...
...
nni/algorithms/compression/pytorch/quantization/quantizers.py
View file @
4784cc6c
...
@@ -41,7 +41,7 @@ class NaiveQuantizer(Quantizer):
...
@@ -41,7 +41,7 @@ class NaiveQuantizer(Quantizer):
wrapper
.
module
.
weight
=
weight
wrapper
.
module
.
weight
=
weight
return
weight
return
weight
def
update_ema
(
biased_ema
,
value
,
decay
,
step
):
def
update_ema
(
biased_ema
,
value
,
decay
):
"""
"""
calculate biased stat and unbiased stat in each step using exponential moving average method
calculate biased stat and unbiased stat in each step using exponential moving average method
...
@@ -53,16 +53,13 @@ def update_ema(biased_ema, value, decay, step):
...
@@ -53,16 +53,13 @@ def update_ema(biased_ema, value, decay, step):
current stat value
current stat value
decay : float
decay : float
the weight of previous stat value, larger means smoother curve
the weight of previous stat value, larger means smoother curve
step : int
current step
Returns
Returns
-------
-------
float, float
float, float
"""
"""
biased_ema
=
biased_ema
*
decay
+
(
1
-
decay
)
*
value
biased_ema
=
biased_ema
*
decay
+
(
1
-
decay
)
*
value
unbiased_ema
=
biased_ema
/
(
1
-
decay
**
step
)
# Bias correction
return
biased_ema
return
biased_ema
,
unbiased_ema
def
update_quantization_param
(
bits
,
rmin
,
rmax
):
def
update_quantization_param
(
bits
,
rmin
,
rmax
):
...
@@ -85,16 +82,10 @@ def update_quantization_param(bits, rmin, rmax):
...
@@ -85,16 +82,10 @@ def update_quantization_param(bits, rmin, rmax):
# extend the [min, max] interval to ensure that it contains 0.
# extend the [min, max] interval to ensure that it contains 0.
# Otherwise, we would not meet the requirement that 0 be an exactly
# Otherwise, we would not meet the requirement that 0 be an exactly
# representable value.
# representable value.
if
rmin
.
is_cuda
:
rmin
=
torch
.
min
(
rmin
,
torch
.
Tensor
([
0
]).
to
(
rmin
.
device
))
rmin
=
torch
.
min
(
rmin
,
torch
.
Tensor
([
0
]).
cuda
())
rmax
=
torch
.
max
(
rmax
,
torch
.
Tensor
([
0
]).
to
(
rmin
.
device
))
rmax
=
torch
.
max
(
rmax
,
torch
.
Tensor
([
0
]).
cuda
())
qmin
=
torch
.
Tensor
([
0
]).
to
(
rmin
.
device
)
qmin
=
torch
.
Tensor
([
0
]).
cuda
()
qmax
=
torch
.
Tensor
([(
1
<<
bits
)
-
1
]).
to
(
rmin
.
device
)
qmax
=
torch
.
Tensor
([(
1
<<
bits
)
-
1
]).
cuda
()
else
:
rmin
=
torch
.
min
(
rmin
,
torch
.
Tensor
([
0
]))
rmax
=
torch
.
max
(
rmax
,
torch
.
Tensor
([
0
]))
qmin
=
torch
.
Tensor
([
0
])
qmax
=
torch
.
Tensor
([(
1
<<
bits
)
-
1
])
# First determine the scale.
# First determine the scale.
scale
=
(
rmax
-
rmin
)
/
(
qmax
-
qmin
)
scale
=
(
rmax
-
rmin
)
/
(
qmax
-
qmin
)
...
@@ -103,7 +94,6 @@ def update_quantization_param(bits, rmin, rmax):
...
@@ -103,7 +94,6 @@ def update_quantization_param(bits, rmin, rmax):
initial_zero_point
=
qmin
-
rmin
/
scale
initial_zero_point
=
qmin
-
rmin
/
scale
# Now we need to nudge the zero point to be an integer
# Now we need to nudge the zero point to be an integer
nudged_zero_point
=
0
if
initial_zero_point
<
qmin
:
if
initial_zero_point
<
qmin
:
nudged_zero_point
=
qmin
nudged_zero_point
=
qmin
elif
initial_zero_point
>
qmax
:
elif
initial_zero_point
>
qmax
:
...
@@ -121,6 +111,15 @@ def get_bits_length(config, quant_type):
...
@@ -121,6 +111,15 @@ def get_bits_length(config, quant_type):
return
config
[
"quant_bits"
].
get
(
quant_type
)
return
config
[
"quant_bits"
].
get
(
quant_type
)
class
QATGrad
(
QuantGrad
):
@
staticmethod
def
quant_backward
(
tensor
,
grad_output
,
quant_type
,
scale
,
zero_point
,
qmin
,
qmax
):
tensor_q
=
QuantGrad
.
_quantize
(
tensor
,
scale
,
zero_point
)
mask
=
(
tensor_q
<
qmin
)
|
(
tensor_q
>
qmax
)
grad_output
[
mask
]
=
0
return
grad_output
class
QAT_Quantizer
(
Quantizer
):
class
QAT_Quantizer
(
Quantizer
):
"""Quantizer defined in:
"""Quantizer defined in:
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference
...
@@ -148,6 +147,7 @@ class QAT_Quantizer(Quantizer):
...
@@ -148,6 +147,7 @@ class QAT_Quantizer(Quantizer):
types of nn.module you want to apply quantization, eg. 'Conv2d'
types of nn.module you want to apply quantization, eg. 'Conv2d'
"""
"""
super
().
__init__
(
model
,
config_list
,
optimizer
)
super
().
__init__
(
model
,
config_list
,
optimizer
)
self
.
quant_grad
=
QATGrad
modules_to_compress
=
self
.
get_modules_to_compress
()
modules_to_compress
=
self
.
get_modules_to_compress
()
self
.
bound_model
.
register_buffer
(
"steps"
,
torch
.
Tensor
([
1
]))
self
.
bound_model
.
register_buffer
(
"steps"
,
torch
.
Tensor
([
1
]))
for
layer
,
config
in
modules_to_compress
:
for
layer
,
config
in
modules_to_compress
:
...
@@ -199,10 +199,8 @@ class QAT_Quantizer(Quantizer):
...
@@ -199,10 +199,8 @@ class QAT_Quantizer(Quantizer):
-------
-------
Tensor
Tensor
"""
"""
if
real_val
.
is_cuda
:
op
.
zero_point
=
op
.
zero_point
.
to
(
real_val
.
device
)
op
.
zero_point
=
op
.
zero_point
.
cuda
()
op
.
scale
=
op
.
scale
.
to
(
real_val
.
device
)
op
.
scale
=
op
.
scale
.
cuda
()
transformed_val
=
op
.
zero_point
+
real_val
/
op
.
scale
transformed_val
=
op
.
zero_point
+
real_val
/
op
.
scale
qmin
=
0
qmin
=
0
qmax
=
(
1
<<
bits
)
-
1
qmax
=
(
1
<<
bits
)
-
1
...
@@ -269,16 +267,17 @@ class QAT_Quantizer(Quantizer):
...
@@ -269,16 +267,17 @@ class QAT_Quantizer(Quantizer):
assert
output_bits
>=
1
,
"quant bits length should be at least 1"
assert
output_bits
>=
1
,
"quant bits length should be at least 1"
if
quant_start_step
>
self
.
bound_model
.
steps
:
if
quant_start_step
>
self
.
bound_model
.
steps
:
module
.
tracked_min_biased
,
module
.
tracked_max_biased
=
torch
.
min
(
output
),
torch
.
max
(
output
)
return
output
return
output
# we don't update output quantization parameters in the evaluation stage
# we don't update output quantization parameters in the evaluation stage
if
wrapper
.
training
:
if
wrapper
.
training
:
current_min
,
current_max
=
torch
.
min
(
output
),
torch
.
max
(
output
)
current_min
,
current_max
=
torch
.
min
(
output
),
torch
.
max
(
output
)
module
.
tracked_min_biased
,
module
.
tracked_min
=
update_ema
(
module
.
tracked_min_biased
,
current_min
,
module
.
tracked_min_biased
=
update_ema
(
module
.
tracked_min_biased
,
current_min
,
module
.
ema_decay
,
self
.
bound_model
.
steps
)
module
.
ema_decay
)
module
.
tracked_max_biased
,
module
.
tracked_max
=
update_ema
(
module
.
tracked_max_biased
,
current_max
,
module
.
tracked_max_biased
=
update_ema
(
module
.
tracked_max_biased
,
current_max
,
module
.
ema_decay
,
self
.
bound_model
.
steps
)
module
.
ema_decay
)
module
.
scale
,
module
.
zero_point
=
update_quantization_param
(
output_bits
,
module
.
tracked_min
,
module
.
tracked_max
)
module
.
scale
,
module
.
zero_point
=
update_quantization_param
(
output_bits
,
module
.
tracked_min
_biased
,
module
.
tracked_max
_biased
)
out
=
self
.
_quantize
(
output_bits
,
module
,
output
)
out
=
self
.
_quantize
(
output_bits
,
module
,
output
)
out
=
self
.
_dequantize
(
module
,
out
)
out
=
self
.
_dequantize
(
module
,
out
)
return
out
return
out
...
@@ -342,7 +341,7 @@ class DoReFaQuantizer(Quantizer):
...
@@ -342,7 +341,7 @@ class DoReFaQuantizer(Quantizer):
class
ClipGrad
(
QuantGrad
):
class
ClipGrad
(
QuantGrad
):
@
staticmethod
@
staticmethod
def
quant_backward
(
tensor
,
grad_output
,
quant_type
):
def
quant_backward
(
tensor
,
grad_output
,
quant_type
,
scale
,
zero_point
,
qmin
,
qmax
):
if
quant_type
==
QuantType
.
QUANT_OUTPUT
:
if
quant_type
==
QuantType
.
QUANT_OUTPUT
:
grad_output
[
torch
.
abs
(
tensor
)
>
1
]
=
0
grad_output
[
torch
.
abs
(
tensor
)
>
1
]
=
0
return
grad_output
return
grad_output
...
...
nni/algorithms/nas/pytorch/cdarts/mutator.py
View file @
4784cc6c
...
@@ -132,7 +132,7 @@ class DartsDiscreteMutator(Mutator):
...
@@ -132,7 +132,7 @@ class DartsDiscreteMutator(Mutator):
----------
----------
model : nn.Module
model : nn.Module
The model to apply the mutator.
The model to apply the mutator.
parent_mutator : Mutator
parent_mutator :
nni.nas.pytorch.mutator.
Mutator
The mutator that provides ``sample_final`` method, that will be called to get the architecture.
The mutator that provides ``sample_final`` method, that will be called to get the architecture.
"""
"""
def
__init__
(
self
,
model
,
parent_mutator
):
def
__init__
(
self
,
model
,
parent_mutator
):
...
...
nni/algorithms/nas/pytorch/spos/trainer.py
View file @
4784cc6c
...
@@ -20,7 +20,7 @@ class SPOSSupernetTrainer(Trainer):
...
@@ -20,7 +20,7 @@ class SPOSSupernetTrainer(Trainer):
----------
----------
model : nn.Module
model : nn.Module
Model with mutables.
Model with mutables.
mutator : Mutator
mutator :
nni.nas.pytorch.mutator.
Mutator
A mutator object that has been initialized with the model.
A mutator object that has been initialized with the model.
loss : callable
loss : callable
Called with logits and targets. Returns a loss tensor.
Called with logits and targets. Returns a loss tensor.
...
...
nni/compression/pytorch/compressor.py
View file @
4784cc6c
...
@@ -580,10 +580,15 @@ class QuantType:
...
@@ -580,10 +580,15 @@ class QuantType:
"""
"""
Enum class for quantization type.
Enum class for quantization type.
"""
"""
QUANT_INPUT
=
'input'
QUANT_INPUT
=
0
QUANT_WEIGHT
=
'weight'
QUANT_WEIGHT
=
1
QUANT_OUTPUT
=
'output'
QUANT_OUTPUT
=
2
QType_Dict
=
{
0
:
"input"
,
1
:
"weight"
,
2
:
"output"
}
class
QuantGrad
(
torch
.
autograd
.
Function
):
class
QuantGrad
(
torch
.
autograd
.
Function
):
"""
"""
...
@@ -628,7 +633,7 @@ class QuantGrad(torch.autograd.Function):
...
@@ -628,7 +633,7 @@ class QuantGrad(torch.autograd.Function):
return
config
[
"quant_bits"
].
get
(
quant_type
)
return
config
[
"quant_bits"
].
get
(
quant_type
)
@
staticmethod
@
staticmethod
def
quant_backward
(
tensor
,
grad_output
,
scale
,
zero_point
,
qmin
,
qmax
):
def
quant_backward
(
tensor
,
grad_output
,
quant_type
,
scale
,
zero_point
,
qmin
,
qmax
):
"""
"""
This method should be overridden by subclasses to provide a customized backward function,
This method should be overridden by subclasses to provide a customized backward function,
default implementation is Straight-Through Estimator
default implementation is Straight-Through Estimator
...
@@ -652,9 +657,6 @@ class QuantGrad(torch.autograd.Function):
...
@@ -652,9 +657,6 @@ class QuantGrad(torch.autograd.Function):
tensor
tensor
gradient of the input of quantization operation
gradient of the input of quantization operation
"""
"""
tensor_q
=
QuantGrad
.
_quantize
(
tensor
,
scale
,
zero_point
)
mask
=
(
tensor_q
<
qmin
)
|
(
tensor_q
>
qmax
)
grad_output
[
mask
]
=
0
return
grad_output
return
grad_output
@
staticmethod
@
staticmethod
...
@@ -668,15 +670,21 @@ class QuantGrad(torch.autograd.Function):
...
@@ -668,15 +670,21 @@ class QuantGrad(torch.autograd.Function):
else
:
else
:
raise
ValueError
(
"unrecognized QuantType."
)
raise
ValueError
(
"unrecognized QuantType."
)
bits
=
QuantGrad
.
get_bits_length
(
wrapper
.
config
,
quant_type
)
qmin
,
qmax
=
torch
.
Tensor
([
0
],
device
=
tensor
.
device
),
torch
.
Tensor
([(
1
<<
bits
)
-
1
],
device
=
tensor
.
device
)
bits
=
QuantGrad
.
get_bits_length
(
wrapper
.
config
,
QType_Dict
[
quant_type
])
ctx
.
save_for_backward
(
tensor
,
wrapper
.
module
.
scale
,
wrapper
.
module
.
zero_point
,
qmin
,
qmax
)
qmin
,
qmax
=
torch
.
Tensor
([
0
]).
to
(
tensor
.
device
),
torch
.
Tensor
([(
1
<<
bits
)
-
1
]).
to
(
tensor
.
device
)
if
hasattr
(
wrapper
.
module
,
'scale'
)
and
hasattr
(
wrapper
.
module
,
'zero_point'
):
scale
=
wrapper
.
module
.
scale
zero_point
=
wrapper
.
module
.
zero_point
else
:
scale
,
zero_point
=
None
,
None
ctx
.
save_for_backward
(
tensor
,
torch
.
Tensor
([
quant_type
]),
scale
,
zero_point
,
qmin
,
qmax
)
return
output
return
output
@
classmethod
@
classmethod
def
backward
(
cls
,
ctx
,
grad_output
):
def
backward
(
cls
,
ctx
,
grad_output
):
tensor
,
scale
,
zero_point
,
qmin
,
qmax
=
ctx
.
saved_variables
tensor
,
quant_type
,
scale
,
zero_point
,
qmin
,
qmax
=
ctx
.
saved_variables
output
=
cls
.
quant_backward
(
tensor
,
grad_output
,
scale
,
zero_point
,
qmin
,
qmax
)
output
=
cls
.
quant_backward
(
tensor
,
grad_output
,
quant_type
,
scale
,
zero_point
,
qmin
,
qmax
)
return
output
,
None
,
None
,
None
return
output
,
None
,
None
,
None
def
_check_weight
(
module
):
def
_check_weight
(
module
):
...
...
nni/compression/pytorch/speedup/infer_shape.py
View file @
4784cc6c
...
@@ -273,7 +273,8 @@ infer_from_inshape = {
...
@@ -273,7 +273,8 @@ infer_from_inshape = {
'aten::mean'
:
lambda
module_masks
,
mask
,
shape
:
mean_inshape
(
module_masks
,
mask
,
shape
),
'aten::mean'
:
lambda
module_masks
,
mask
,
shape
:
mean_inshape
(
module_masks
,
mask
,
shape
),
'Dropout'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
),
'Dropout'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
),
'Dropout2d'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
),
'Dropout2d'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
),
'aten::dropout'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
)
'aten::dropout'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
),
'aten::detach'
:
lambda
module_masks
,
mask
:
dropout_inshape
(
module_masks
,
mask
)
}
}
"""
"""
...
@@ -308,7 +309,8 @@ infer_from_outshape = {
...
@@ -308,7 +309,8 @@ infer_from_outshape = {
'aten::mean'
:
lambda
module_masks
,
mask
,
shape
:
mean_outshape
(
module_masks
,
mask
,
shape
),
'aten::mean'
:
lambda
module_masks
,
mask
,
shape
:
mean_outshape
(
module_masks
,
mask
,
shape
),
'Dropout'
:
lambda
module_masks
,
mask
:
dropout_outshape
(
module_masks
,
mask
),
'Dropout'
:
lambda
module_masks
,
mask
:
dropout_outshape
(
module_masks
,
mask
),
'Dropout2d'
:
lambda
module_masks
,
mask
:
dropout_outshape
(
module_masks
,
mask
),
'Dropout2d'
:
lambda
module_masks
,
mask
:
dropout_outshape
(
module_masks
,
mask
),
'aten::dropout'
:
lambda
module_masks
,
mask
:
dropout_outshape
(
module_masks
,
mask
)
'aten::dropout'
:
lambda
module_masks
,
mask
:
dropout_outshape
(
module_masks
,
mask
),
'aten::detach'
:
lambda
module_masks
,
mask
:
dropout_outshape
(
module_masks
,
mask
)
}
}
...
@@ -889,23 +891,18 @@ def conv2d_mask(module_masks, mask):
...
@@ -889,23 +891,18 @@ def conv2d_mask(module_masks, mask):
sum_idx
=
(
1
,
2
,
3
)
if
dim
==
0
else
(
0
,
2
,
3
)
sum_idx
=
(
1
,
2
,
3
)
if
dim
==
0
else
(
0
,
2
,
3
)
index
=
torch
.
nonzero
(
weight_mask
.
abs
().
sum
(
index
=
torch
.
nonzero
(
weight_mask
.
abs
().
sum
(
sum_idx
)
!=
0
,
as_tuple
=
True
)[
0
]
sum_idx
)
!=
0
,
as_tuple
=
True
)[
0
]
if
len
(
index
)
==
weight_mask
.
shape
[
dim
]:
# full mask
index
=
None
if
index
is
None
:
index
=
index
.
long
().
to
(
weight_mask
.
device
)
return
None
,
None
,
None
weight_cmask
=
CoarseMask
(
num_dim
=
4
)
else
:
weight_cmask
.
add_index_mask
(
dim
=
dim
,
index
=
index
)
index
=
index
.
long
().
to
(
weight_mask
.
device
)
bias_cmask
=
None
weight_cmask
=
CoarseMask
(
num_dim
=
4
)
if
dim
==
0
and
'bias'
in
mask
and
mask
[
'bias'
]
is
not
None
:
weight_cmask
.
add_index_mask
(
dim
=
dim
,
index
=
index
)
bias_index
=
torch
.
nonzero
(
mask
[
'bias'
],
as_tuple
=
True
)[
0
]
bias_cmask
=
None
assert
torch
.
all
(
torch
.
eq
(
index
,
bias_index
)),
\
if
dim
==
0
and
'bias'
in
mask
and
mask
[
'bias'
]
is
not
None
:
"bias mask should be consistent with weight mask"
bias_index
=
torch
.
nonzero
(
mask
[
'bias'
],
as_tuple
=
True
)[
0
]
bias_cmask
=
CoarseMask
(
num_dim
=
1
)
assert
torch
.
all
(
torch
.
eq
(
index
,
bias_index
)),
\
bias_cmask
.
add_index_mask
(
dim
=
0
,
index
=
bias_index
)
"bias mask should be consistent with weight mask"
return
index
,
weight_cmask
,
bias_cmask
bias_cmask
=
CoarseMask
(
num_dim
=
1
)
bias_cmask
.
add_index_mask
(
dim
=
0
,
index
=
bias_index
)
return
index
,
weight_cmask
,
bias_cmask
index
,
weight_cmask
,
bias_cmask
=
convert_to_coarse_mask
(
index
,
weight_cmask
,
bias_cmask
=
convert_to_coarse_mask
(
mask
,
dim
=
conv_prune_dim
)
mask
,
dim
=
conv_prune_dim
)
...
@@ -960,6 +957,7 @@ def conv2d_inshape(module_masks, mask):
...
@@ -960,6 +957,7 @@ def conv2d_inshape(module_masks, mask):
# the same conv layer may be accessed more
# the same conv layer may be accessed more
# than once, such as a concat operation.
# than once, such as a concat operation.
# mask conflict should be solved by fix_mask_conflict before speedup
# mask conflict should be solved by fix_mask_conflict before speedup
assert
module_masks
.
input_mask
==
mask
assert
module_masks
.
input_mask
==
mask
# shape changes pass through depthwise conv layers
# shape changes pass through depthwise conv layers
...
...
nni/compression/pytorch/utils/mask_conflict.py
View file @
4784cc6c
...
@@ -31,6 +31,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
...
@@ -31,6 +31,7 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
# if the input is the path of the mask_file
# if the input is the path of the mask_file
assert
os
.
path
.
exists
(
masks
)
assert
os
.
path
.
exists
(
masks
)
masks
=
torch
.
load
(
masks
)
masks
=
torch
.
load
(
masks
)
assert
len
(
masks
)
>
0
,
'Mask tensor cannot be empty'
# if the user uses the model and dummy_input to trace the model, we
# if the user uses the model and dummy_input to trace the model, we
# should get the traced model manually, so that we only trace the
# should get the traced model manually, so that we only trace the
# model once, GroupMaskConflict and ChannelMaskConflict will reuse
# model once, GroupMaskConflict and ChannelMaskConflict will reuse
...
@@ -127,6 +128,7 @@ class CatMaskPadding(MaskFix):
...
@@ -127,6 +128,7 @@ class CatMaskPadding(MaskFix):
for
layer
in
layers
:
for
layer
in
layers
:
if
layer
in
self
.
masks
:
if
layer
in
self
.
masks
:
continue
continue
module
=
name_to_module
[
layer
]
module
=
name_to_module
[
layer
]
w_shape
=
module
.
weight
.
data
.
size
()
w_shape
=
module
.
weight
.
data
.
size
()
w_mask
=
torch
.
ones
(
w_shape
).
to
(
device
)
w_mask
=
torch
.
ones
(
w_shape
).
to
(
device
)
...
@@ -136,6 +138,7 @@ class CatMaskPadding(MaskFix):
...
@@ -136,6 +138,7 @@ class CatMaskPadding(MaskFix):
b_shape
=
module
.
bias
.
data
.
size
()
b_shape
=
module
.
bias
.
data
.
size
()
b_mask
=
torch
.
ones
(
b_shape
).
to
(
device
)
b_mask
=
torch
.
ones
(
b_shape
).
to
(
device
)
self
.
masks
[
layer
]
=
{
'weight'
:
w_mask
,
'bias'
:
b_mask
}
self
.
masks
[
layer
]
=
{
'weight'
:
w_mask
,
'bias'
:
b_mask
}
return
self
.
masks
return
self
.
masks
...
@@ -250,6 +253,10 @@ class ChannelMaskConflict(MaskFix):
...
@@ -250,6 +253,10 @@ class ChannelMaskConflict(MaskFix):
self
.
model
,
self
.
dummy_input
,
self
.
traced
)
self
.
model
,
self
.
dummy_input
,
self
.
traced
)
depen_sets
=
channel_depen
.
dependency_sets
depen_sets
=
channel_depen
.
dependency_sets
sum_idx
=
(
1
,
2
,
3
)
if
self
.
conv_prune_dim
==
0
else
(
0
,
2
,
3
)
sum_idx
=
(
1
,
2
,
3
)
if
self
.
conv_prune_dim
==
0
else
(
0
,
2
,
3
)
(
_tmp_name
,
_tmp_tensor
)
=
list
(
self
.
masks
.
items
())[
0
]
device
=
_tmp_tensor
[
'weight'
].
device
for
dset
in
depen_sets
:
for
dset
in
depen_sets
:
if
len
(
dset
)
<=
1
:
if
len
(
dset
)
<=
1
:
continue
continue
...
@@ -301,7 +308,7 @@ class ChannelMaskConflict(MaskFix):
...
@@ -301,7 +308,7 @@ class ChannelMaskConflict(MaskFix):
for
i
,
dim_mask
in
enumerate
(
channel_masks
):
for
i
,
dim_mask
in
enumerate
(
channel_masks
):
if
dim_mask
is
None
:
if
dim_mask
is
None
:
channel_masks
[
i
]
=
torch
.
ones
(
num_channels
).
int
()
channel_masks
[
i
]
=
torch
.
ones
(
num_channels
).
int
()
.
to
(
device
)
# merge masks with 'or'
# merge masks with 'or'
merged_channel_mask
=
channel_masks
[
0
].
clone
()
merged_channel_mask
=
channel_masks
[
0
].
clone
()
...
...
nni/experiment/config/common.py
View file @
4784cc6c
...
@@ -65,15 +65,19 @@ class ExperimentConfig(ConfigBase):
...
@@ -65,15 +65,19 @@ class ExperimentConfig(ConfigBase):
tuner
:
Optional
[
_AlgorithmConfig
]
=
None
tuner
:
Optional
[
_AlgorithmConfig
]
=
None
accessor
:
Optional
[
_AlgorithmConfig
]
=
None
accessor
:
Optional
[
_AlgorithmConfig
]
=
None
advisor
:
Optional
[
_AlgorithmConfig
]
=
None
advisor
:
Optional
[
_AlgorithmConfig
]
=
None
training_service
:
TrainingServiceConfig
training_service
:
Union
[
TrainingServiceConfig
,
List
[
TrainingServiceConfig
]]
def
__init__
(
self
,
training_service_platform
:
Optional
[
str
]
=
None
,
**
kwargs
):
def
__init__
(
self
,
training_service_platform
:
Optional
[
Union
[
str
,
List
[
str
]]
]
=
None
,
**
kwargs
):
kwargs
=
util
.
case_insensitive
(
kwargs
)
kwargs
=
util
.
case_insensitive
(
kwargs
)
if
training_service_platform
is
not
None
:
if
training_service_platform
is
not
None
:
assert
'trainingservice'
not
in
kwargs
assert
'trainingservice'
not
in
kwargs
kwargs
[
'trainingservice'
]
=
util
.
training_service_config_factory
(
training_service_platform
)
kwargs
[
'trainingservice'
]
=
util
.
training_service_config_factory
(
platform
=
training_service_platform
)
elif
isinstance
(
kwargs
.
get
(
'trainingservice'
),
dict
):
elif
isinstance
(
kwargs
.
get
(
'trainingservice'
),
(
dict
,
list
)):
kwargs
[
'trainingservice'
]
=
util
.
training_service_config_factory
(
**
kwargs
[
'trainingservice'
])
# dict means a single training service
# list means hybrid training service
kwargs
[
'trainingservice'
]
=
util
.
training_service_config_factory
(
config
=
kwargs
[
'trainingservice'
])
else
:
raise
RuntimeError
(
'Unsupported Training service configuration!'
)
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
def
validate
(
self
,
initialized_tuner
:
bool
=
False
)
->
None
:
def
validate
(
self
,
initialized_tuner
:
bool
=
False
)
->
None
:
...
...
nni/experiment/config/convert.py
View file @
4784cc6c
...
@@ -18,8 +18,29 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
...
@@ -18,8 +18,29 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
data
=
config
.
json
()
data
=
config
.
json
()
ts
=
data
.
pop
(
'trainingService'
)
ts
=
data
.
pop
(
'trainingService'
)
if
ts
[
'platform'
]
==
'openpai'
:
ts
[
'platform'
]
=
'pai'
data
[
'trial'
]
=
{
'command'
:
data
.
pop
(
'trialCommand'
),
'codeDir'
:
data
.
pop
(
'trialCodeDirectory'
),
}
if
'trialGpuNumber'
in
data
:
data
[
'trial'
][
'gpuNum'
]
=
data
.
pop
(
'trialGpuNumber'
)
if
isinstance
(
ts
,
list
):
hybrid_names
=
[]
for
conf
in
ts
:
if
conf
[
'platform'
]
==
'openpai'
:
conf
[
'platform'
]
=
'pai'
hybrid_names
.
append
(
conf
[
'platform'
])
_handle_training_service
(
conf
,
data
)
data
[
'trainingServicePlatform'
]
=
'hybrid'
data
[
'hybridConfig'
]
=
{
'trainingServicePlatforms'
:
hybrid_names
}
else
:
if
ts
[
'platform'
]
==
'openpai'
:
ts
[
'platform'
]
=
'pai'
data
[
'trainingServicePlatform'
]
=
ts
[
'platform'
]
_handle_training_service
(
ts
,
data
)
data
[
'authorName'
]
=
'N/A'
data
[
'authorName'
]
=
'N/A'
data
[
'experimentName'
]
=
data
.
get
(
'experimentName'
,
'N/A'
)
data
[
'experimentName'
]
=
data
.
get
(
'experimentName'
,
'N/A'
)
...
@@ -27,7 +48,7 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
...
@@ -27,7 +48,7 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
if
data
[
'debug'
]:
if
data
[
'debug'
]:
data
[
'versionCheck'
]
=
False
data
[
'versionCheck'
]
=
False
data
[
'maxTrialNum'
]
=
data
.
pop
(
'maxTrialNumber'
,
99999
)
data
[
'maxTrialNum'
]
=
data
.
pop
(
'maxTrialNumber'
,
99999
)
data
[
'trainingServicePlatform'
]
=
ts
[
'platform'
]
ss
=
data
.
pop
(
'searchSpace'
,
None
)
ss
=
data
.
pop
(
'searchSpace'
,
None
)
ss_file
=
data
.
pop
(
'searchSpaceFile'
,
None
)
ss_file
=
data
.
pop
(
'searchSpaceFile'
,
None
)
if
ss
is
not
None
:
if
ss
is
not
None
:
...
@@ -58,14 +79,9 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
...
@@ -58,14 +79,9 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
if
tuner_gpu_indices
is
not
None
:
if
tuner_gpu_indices
is
not
None
:
data
[
'tuner'
][
'gpuIndicies'
]
=
tuner_gpu_indices
data
[
'tuner'
][
'gpuIndicies'
]
=
tuner_gpu_indices
data
[
'trial'
]
=
{
return
data
'command'
:
data
.
pop
(
'trialCommand'
),
'codeDir'
:
data
.
pop
(
'trialCodeDirectory'
),
}
if
'trialGpuNumber'
in
data
:
data
[
'trial'
][
'gpuNum'
]
=
data
.
pop
(
'trialGpuNumber'
)
def
_handle_training_service
(
ts
,
data
):
if
ts
[
'platform'
]
==
'local'
:
if
ts
[
'platform'
]
==
'local'
:
data
[
'localConfig'
]
=
{
data
[
'localConfig'
]
=
{
'useActiveGpu'
:
ts
.
get
(
'useActiveGpu'
,
False
),
'useActiveGpu'
:
ts
.
get
(
'useActiveGpu'
,
False
),
...
@@ -98,6 +114,9 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
...
@@ -98,6 +114,9 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
data
[
'trial'
][
'image'
]
=
ts
[
'dockerImage'
]
data
[
'trial'
][
'image'
]
=
ts
[
'dockerImage'
]
data
[
'trial'
][
'nniManagerNFSMountPath'
]
=
ts
[
'localStorageMountPoint'
]
data
[
'trial'
][
'nniManagerNFSMountPath'
]
=
ts
[
'localStorageMountPoint'
]
data
[
'trial'
][
'containerNFSMountPath'
]
=
ts
[
'containerStorageMountPoint'
]
data
[
'trial'
][
'containerNFSMountPath'
]
=
ts
[
'containerStorageMountPoint'
]
data
[
'trial'
][
'paiStorageConfigName'
]
=
ts
[
'storageConfigName'
]
data
[
'trial'
][
'cpuNum'
]
=
ts
[
'trialCpuNumber'
]
data
[
'trial'
][
'memoryMB'
]
=
ts
[
'trialMemorySize'
]
data
[
'paiConfig'
]
=
{
data
[
'paiConfig'
]
=
{
'userName'
:
ts
[
'username'
],
'userName'
:
ts
[
'username'
],
'token'
:
ts
[
'token'
],
'token'
:
ts
[
'token'
],
...
@@ -140,8 +159,6 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
...
@@ -140,8 +159,6 @@ def to_v1_yaml(config: ExperimentConfig, skip_nnictl: bool = False) -> Dict[str,
elif
ts
[
'platform'
]
==
'adl'
:
elif
ts
[
'platform'
]
==
'adl'
:
data
[
'trial'
][
'image'
]
=
ts
[
'dockerImage'
]
data
[
'trial'
][
'image'
]
=
ts
[
'dockerImage'
]
return
data
def
_convert_gpu_indices
(
indices
):
def
_convert_gpu_indices
(
indices
):
return
','
.
join
(
str
(
idx
)
for
idx
in
indices
)
if
indices
is
not
None
else
None
return
','
.
join
(
str
(
idx
)
for
idx
in
indices
)
if
indices
is
not
None
else
None
...
@@ -175,19 +192,34 @@ def to_cluster_metadata(config: ExperimentConfig) -> List[Dict[str, Any]]:
...
@@ -175,19 +192,34 @@ def to_cluster_metadata(config: ExperimentConfig) -> List[Dict[str, Any]]:
experiment_config
=
to_v1_yaml
(
config
,
skip_nnictl
=
True
)
experiment_config
=
to_v1_yaml
(
config
,
skip_nnictl
=
True
)
ret
=
[]
ret
=
[]
if
config
.
training_service
.
platform
==
'local'
:
if
isinstance
(
config
.
training_service
,
list
):
hybrid_conf
=
dict
()
hybrid_conf
[
'hybrid_config'
]
=
experiment_config
[
'hybridConfig'
]
for
conf
in
config
.
training_service
:
metadata
=
_get_cluster_metadata
(
conf
.
platform
,
experiment_config
)
if
metadata
is
not
None
:
hybrid_conf
.
update
(
metadata
)
ret
.
append
(
hybrid_conf
)
else
:
metadata
=
_get_cluster_metadata
(
config
.
training_service
.
platform
,
experiment_config
)
if
metadata
is
not
None
:
ret
.
append
(
metadata
)
if
experiment_config
.
get
(
'nniManagerIp'
)
is
not
None
:
ret
.
append
({
'nni_manager_ip'
:
{
'nniManagerIp'
:
experiment_config
[
'nniManagerIp'
]}})
ret
.
append
({
'trial_config'
:
experiment_config
[
'trial'
]})
return
ret
def
_get_cluster_metadata
(
platform
:
str
,
experiment_config
)
->
Dict
:
if
platform
==
'local'
:
request_data
=
dict
()
request_data
=
dict
()
request_data
[
'local_config'
]
=
experiment_config
[
'localConfig'
]
request_data
[
'local_config'
]
=
experiment_config
[
'localConfig'
]
if
request_data
[
'local_config'
]:
if
request_data
[
'local_config'
]:
if
request_data
[
'local_config'
].
get
(
'gpuIndices'
)
and
isinstance
(
request_data
[
'local_config'
].
get
(
'gpuIndices'
),
int
):
if
request_data
[
'local_config'
].
get
(
'gpuIndices'
)
and
isinstance
(
request_data
[
'local_config'
].
get
(
'gpuIndices'
),
int
):
request_data
[
'local_config'
][
'gpuIndices'
]
=
str
(
request_data
[
'local_config'
].
get
(
'gpuIndices'
))
request_data
[
'local_config'
][
'gpuIndices'
]
=
str
(
request_data
[
'local_config'
].
get
(
'gpuIndices'
))
if
request_data
[
'local_config'
].
get
(
'maxTrialNumOnEachGpu'
):
return
request_data
request_data
[
'local_config'
][
'maxTrialNumOnEachGpu'
]
=
request_data
[
'local_config'
].
get
(
'maxTrialNumOnEachGpu'
)
if
request_data
[
'local_config'
].
get
(
'useActiveGpu'
):
request_data
[
'local_config'
][
'useActiveGpu'
]
=
request_data
[
'local_config'
].
get
(
'useActiveGpu'
)
ret
.
append
(
request_data
)
elif
config
.
training_service
.
platform
==
'remote'
:
elif
platform
==
'remote'
:
request_data
=
dict
()
request_data
=
dict
()
if
experiment_config
.
get
(
'remoteConfig'
):
if
experiment_config
.
get
(
'remoteConfig'
):
request_data
[
'remote_config'
]
=
experiment_config
[
'remoteConfig'
]
request_data
[
'remote_config'
]
=
experiment_config
[
'remoteConfig'
]
...
@@ -198,31 +230,25 @@ def to_cluster_metadata(config: ExperimentConfig) -> List[Dict[str, Any]]:
...
@@ -198,31 +230,25 @@ def to_cluster_metadata(config: ExperimentConfig) -> List[Dict[str, Any]]:
for
i
in
range
(
len
(
request_data
[
'machine_list'
])):
for
i
in
range
(
len
(
request_data
[
'machine_list'
])):
if
isinstance
(
request_data
[
'machine_list'
][
i
].
get
(
'gpuIndices'
),
int
):
if
isinstance
(
request_data
[
'machine_list'
][
i
].
get
(
'gpuIndices'
),
int
):
request_data
[
'machine_list'
][
i
][
'gpuIndices'
]
=
str
(
request_data
[
'machine_list'
][
i
].
get
(
'gpuIndices'
))
request_data
[
'machine_list'
][
i
][
'gpuIndices'
]
=
str
(
request_data
[
'machine_list'
][
i
].
get
(
'gpuIndices'
))
ret
.
append
(
request_data
)
ret
urn
request_data
elif
config
.
training_service
.
platform
==
'openpai'
:
elif
platform
==
'openpai'
:
ret
.
append
(
{
'pai_config'
:
experiment_config
[
'paiConfig'
]}
)
ret
urn
{
'pai_config'
:
experiment_config
[
'paiConfig'
]}
elif
config
.
training_service
.
platform
==
'aml'
:
elif
platform
==
'aml'
:
ret
.
append
(
{
'aml_config'
:
experiment_config
[
'amlConfig'
]}
)
ret
urn
{
'aml_config'
:
experiment_config
[
'amlConfig'
]}
elif
config
.
training_service
.
platform
==
'kubeflow'
:
elif
platform
==
'kubeflow'
:
ret
.
append
(
{
'kubeflow_config'
:
experiment_config
[
'kubeflowConfig'
]}
)
ret
urn
{
'kubeflow_config'
:
experiment_config
[
'kubeflowConfig'
]}
elif
config
.
training_service
.
platform
==
'frameworkcontroller'
:
elif
platform
==
'frameworkcontroller'
:
ret
.
append
(
{
'frameworkcontroller_config'
:
experiment_config
[
'frameworkcontrollerConfig'
]}
)
ret
urn
{
'frameworkcontroller_config'
:
experiment_config
[
'frameworkcontrollerConfig'
]}
elif
config
.
training_service
.
platform
==
'adl'
:
elif
platform
==
'adl'
:
pass
return
None
else
:
else
:
raise
RuntimeError
(
'Unsupported training service '
+
config
.
training_service
.
platform
)
raise
RuntimeError
(
'Unsupported training service '
+
platform
)
if
experiment_config
.
get
(
'nniManagerIp'
)
is
not
None
:
ret
.
append
({
'nni_manager_ip'
:
{
'nniManagerIp'
:
experiment_config
[
'nniManagerIp'
]}})
ret
.
append
({
'trial_config'
:
experiment_config
[
'trial'
]})
return
ret
def
to_rest_json
(
config
:
ExperimentConfig
)
->
Dict
[
str
,
Any
]:
def
to_rest_json
(
config
:
ExperimentConfig
)
->
Dict
[
str
,
Any
]:
experiment_config
=
to_v1_yaml
(
config
,
skip_nnictl
=
True
)
experiment_config
=
to_v1_yaml
(
config
,
skip_nnictl
=
True
)
...
...
nni/experiment/config/openpai.py
View file @
4784cc6c
...
@@ -2,7 +2,7 @@
...
@@ -2,7 +2,7 @@
# Licensed under the MIT license.
# Licensed under the MIT license.
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
pathlib
import
Path
,
PurePosixPath
from
typing
import
Any
,
Dict
,
Optional
from
typing
import
Any
,
Dict
,
Optional
from
.base
import
PathLike
from
.base
import
PathLike
...
@@ -17,6 +17,9 @@ class OpenpaiConfig(TrainingServiceConfig):
...
@@ -17,6 +17,9 @@ class OpenpaiConfig(TrainingServiceConfig):
host
:
str
host
:
str
username
:
str
username
:
str
token
:
str
token
:
str
trial_cpu_number
:
int
trial_memory_size
:
str
storage_config_name
:
str
docker_image
:
str
=
'msranni/nni:latest'
docker_image
:
str
=
'msranni/nni:latest'
local_storage_mount_point
:
PathLike
local_storage_mount_point
:
PathLike
container_storage_mount_point
:
str
container_storage_mount_point
:
str
...
@@ -34,7 +37,7 @@ class OpenpaiConfig(TrainingServiceConfig):
...
@@ -34,7 +37,7 @@ class OpenpaiConfig(TrainingServiceConfig):
_validation_rules
=
{
_validation_rules
=
{
'platform'
:
lambda
value
:
(
value
==
'openpai'
,
'cannot be modified'
),
'platform'
:
lambda
value
:
(
value
==
'openpai'
,
'cannot be modified'
),
'local_storage_mount_point'
:
lambda
value
:
Path
(
value
).
is_dir
(),
'local_storage_mount_point'
:
lambda
value
:
Path
(
value
).
is_dir
(),
'container_storage_mount_point'
:
lambda
value
:
(
Path
(
value
).
is_absolute
(),
'is not absolute'
),
'container_storage_mount_point'
:
lambda
value
:
(
PurePosix
Path
(
value
).
is_absolute
(),
'is not absolute'
),
'openpai_config_file'
:
lambda
value
:
Path
(
value
).
is_file
()
'openpai_config_file'
:
lambda
value
:
Path
(
value
).
is_file
()
}
}
...
...
nni/experiment/config/remote.py
View file @
4784cc6c
...
@@ -4,6 +4,7 @@
...
@@ -4,6 +4,7 @@
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
List
,
Optional
,
Union
from
typing
import
List
,
Optional
,
Union
import
warnings
from
.base
import
ConfigBase
,
PathLike
from
.base
import
ConfigBase
,
PathLike
from
.common
import
TrainingServiceConfig
from
.common
import
TrainingServiceConfig
...
@@ -17,7 +18,7 @@ class RemoteMachineConfig(ConfigBase):
...
@@ -17,7 +18,7 @@ class RemoteMachineConfig(ConfigBase):
port
:
int
=
22
port
:
int
=
22
user
:
str
user
:
str
password
:
Optional
[
str
]
=
None
password
:
Optional
[
str
]
=
None
ssh_key_file
:
Optional
[
PathLike
]
=
None
ssh_key_file
:
PathLike
=
None
#'~/.ssh/id_rsa'
ssh_passphrase
:
Optional
[
str
]
=
None
ssh_passphrase
:
Optional
[
str
]
=
None
use_active_gpu
:
bool
=
False
use_active_gpu
:
bool
=
False
max_trial_number_per_gpu
:
int
=
1
max_trial_number_per_gpu
:
int
=
1
...
@@ -39,6 +40,8 @@ class RemoteMachineConfig(ConfigBase):
...
@@ -39,6 +40,8 @@ class RemoteMachineConfig(ConfigBase):
super
().
validate
()
super
().
validate
()
if
self
.
password
is
None
and
not
Path
(
self
.
ssh_key_file
).
is_file
():
if
self
.
password
is
None
and
not
Path
(
self
.
ssh_key_file
).
is_file
():
raise
ValueError
(
f
'Password is not provided and cannot find SSH key file "
{
self
.
ssh_key_file
}
"'
)
raise
ValueError
(
f
'Password is not provided and cannot find SSH key file "
{
self
.
ssh_key_file
}
"'
)
if
self
.
password
:
warnings
.
warn
(
'Password will be exposed through web UI in plain text. We recommend to use SSH key file.'
)
@
dataclass
(
init
=
False
)
@
dataclass
(
init
=
False
)
class
RemoteConfig
(
TrainingServiceConfig
):
class
RemoteConfig
(
TrainingServiceConfig
):
...
@@ -51,6 +54,10 @@ class RemoteConfig(TrainingServiceConfig):
...
@@ -51,6 +54,10 @@ class RemoteConfig(TrainingServiceConfig):
kwargs
[
'machinelist'
]
=
util
.
load_config
(
RemoteMachineConfig
,
kwargs
.
get
(
'machinelist'
))
kwargs
[
'machinelist'
]
=
util
.
load_config
(
RemoteMachineConfig
,
kwargs
.
get
(
'machinelist'
))
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
_canonical_rules
=
{
'machine_list'
:
lambda
value
:
[
config
.
canonical
()
for
config
in
value
]
}
_validation_rules
=
{
_validation_rules
=
{
'platform'
:
lambda
value
:
(
value
==
'remote'
,
'cannot be modified'
)
'platform'
:
lambda
value
:
(
value
==
'remote'
,
'cannot be modified'
)
}
}
nni/experiment/config/util.py
View file @
4784cc6c
...
@@ -8,7 +8,7 @@ Miscellaneous utility functions.
...
@@ -8,7 +8,7 @@ Miscellaneous utility functions.
import
math
import
math
import
os.path
import
os.path
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
,
Dict
,
Optional
,
Union
from
typing
import
Any
,
Dict
,
Optional
,
Union
,
List
PathLike
=
Union
[
Path
,
str
]
PathLike
=
Union
[
Path
,
str
]
...
@@ -29,12 +29,26 @@ def canonical_path(path: Optional[PathLike]) -> Optional[str]:
...
@@ -29,12 +29,26 @@ def canonical_path(path: Optional[PathLike]) -> Optional[str]:
def
count
(
*
values
)
->
int
:
def
count
(
*
values
)
->
int
:
return
sum
(
value
is
not
None
and
value
is
not
False
for
value
in
values
)
return
sum
(
value
is
not
None
and
value
is
not
False
for
value
in
values
)
def
training_service_config_factory
(
platform
:
str
,
**
kwargs
):
# -> TrainingServiceConfig
def
training_service_config_factory
(
platform
:
Union
[
str
,
List
[
str
]]
=
None
,
config
:
Union
[
List
,
Dict
]
=
None
):
# -> TrainingServiceConfig
from
.common
import
TrainingServiceConfig
from
.common
import
TrainingServiceConfig
for
cls
in
TrainingServiceConfig
.
__subclasses__
():
ts_configs
=
[]
if
cls
.
platform
==
platform
:
if
platform
is
not
None
:
return
cls
(
**
kwargs
)
assert
config
is
None
raise
ValueError
(
f
'Unrecognized platform
{
platform
}
'
)
platforms
=
platform
if
isinstance
(
platform
,
list
)
else
[
platform
]
for
cls
in
TrainingServiceConfig
.
__subclasses__
():
if
cls
.
platform
in
platforms
:
ts_configs
.
append
(
cls
())
if
len
(
ts_configs
)
<
len
(
platforms
):
raise
RuntimeError
(
'There is unrecognized platform!'
)
else
:
assert
config
is
not
None
supported_platforms
=
{
cls
.
platform
:
cls
for
cls
in
TrainingServiceConfig
.
__subclasses__
()}
configs
=
config
if
isinstance
(
config
,
list
)
else
[
config
]
for
conf
in
configs
:
if
conf
[
'platform'
]
not
in
supported_platforms
:
raise
RuntimeError
(
f
'Unrecognized platform
{
conf
[
"platform"
]
}
'
)
ts_configs
.
append
(
supported_platforms
[
conf
[
'platform'
]](
**
conf
))
return
ts_configs
if
len
(
ts_configs
)
>
1
else
ts_configs
[
0
]
def
load_config
(
Type
,
value
):
def
load_config
(
Type
,
value
):
if
isinstance
(
value
,
list
):
if
isinstance
(
value
,
list
):
...
...
nni/experiment/experiment.py
View file @
4784cc6c
import
atexit
import
atexit
import
logging
import
logging
from
pathlib
import
Path
import
socket
import
socket
from
subprocess
import
Popen
from
subprocess
import
Popen
from
threading
import
Thread
from
threading
import
Thread
import
time
import
time
from
typing
import
Optional
,
overload
from
typing
import
Optional
,
Union
,
List
,
overload
import
colorama
import
colorama
import
psutil
import
psutil
...
@@ -15,8 +16,10 @@ from nni.tuner import Tuner
...
@@ -15,8 +16,10 @@ from nni.tuner import Tuner
from
.config
import
ExperimentConfig
from
.config
import
ExperimentConfig
from
.
import
launcher
from
.
import
launcher
from
.
import
management
from
.pipe
import
Pipe
from
.pipe
import
Pipe
from
.
import
rest
from
.
import
rest
from
..tools.nnictl.command_utils
import
kill_command
nni
.
runtime
.
log
.
init_logger_experiment
()
nni
.
runtime
.
log
.
init_logger_experiment
()
_logger
=
logging
.
getLogger
(
'nni.experiment'
)
_logger
=
logging
.
getLogger
(
'nni.experiment'
)
...
@@ -51,7 +54,7 @@ class Experiment:
...
@@ -51,7 +54,7 @@ class Experiment:
...
...
@
overload
@
overload
def
__init__
(
self
,
tuner
:
Tuner
,
training_service
:
str
)
->
None
:
def
__init__
(
self
,
tuner
:
Tuner
,
training_service
:
Union
[
str
,
List
[
str
]]
)
->
None
:
"""
"""
Prepare an experiment, leaving configuration fields to be set later.
Prepare an experiment, leaving configuration fields to be set later.
...
@@ -69,12 +72,13 @@ class Experiment:
...
@@ -69,12 +72,13 @@ class Experiment:
A tuner instance.
A tuner instance.
training_service
training_service
Name of training service.
Name of training service.
Supported value: "local", "remote", "openpai".
Supported value: "local", "remote", "openpai"
, "aml", "kubeflow", "frameworkcontroller", "adl" and hybrid training service
.
"""
"""
...
...
def
__init__
(
self
,
tuner
:
Tuner
,
config
=
None
,
training_service
=
None
):
def
__init__
(
self
,
tuner
:
Tuner
,
config
=
None
,
training_service
=
None
):
self
.
config
:
ExperimentConfig
self
.
config
:
ExperimentConfig
self
.
id
:
Optional
[
str
]
=
None
self
.
port
:
Optional
[
int
]
=
None
self
.
port
:
Optional
[
int
]
=
None
self
.
tuner
:
Tuner
=
tuner
self
.
tuner
:
Tuner
=
tuner
self
.
_proc
:
Optional
[
Popen
]
=
None
self
.
_proc
:
Optional
[
Popen
]
=
None
...
@@ -82,7 +86,7 @@ class Experiment:
...
@@ -82,7 +86,7 @@ class Experiment:
self
.
_dispatcher
:
Optional
[
MsgDispatcher
]
=
None
self
.
_dispatcher
:
Optional
[
MsgDispatcher
]
=
None
self
.
_dispatcher_thread
:
Optional
[
Thread
]
=
None
self
.
_dispatcher_thread
:
Optional
[
Thread
]
=
None
if
isinstance
(
config
,
str
):
if
isinstance
(
config
,
(
str
,
list
)
):
config
,
training_service
=
None
,
config
config
,
training_service
=
None
,
config
if
config
is
None
:
if
config
is
None
:
...
@@ -107,10 +111,15 @@ class Experiment:
...
@@ -107,10 +111,15 @@ class Experiment:
"""
"""
atexit
.
register
(
self
.
stop
)
atexit
.
register
(
self
.
stop
)
if
debug
:
self
.
id
=
management
.
generate_experiment_id
()
logging
.
getLogger
(
'nni'
).
setLevel
(
logging
.
DEBUG
)
self
.
_proc
,
self
.
_pipe
=
launcher
.
start_experiment
(
self
.
config
,
port
,
debug
)
if
self
.
config
.
experiment_working_directory
is
not
None
:
log_dir
=
Path
(
self
.
config
.
experiment_working_directory
,
self
.
id
,
'log'
)
else
:
log_dir
=
Path
.
home
()
/
f
'nni-experiments/
{
self
.
id
}
/log'
nni
.
runtime
.
log
.
start_experiment_log
(
self
.
id
,
log_dir
,
debug
)
self
.
_proc
,
self
.
_pipe
=
launcher
.
start_experiment
(
self
.
id
,
self
.
config
,
port
,
debug
)
assert
self
.
_proc
is
not
None
assert
self
.
_proc
is
not
None
assert
self
.
_pipe
is
not
None
assert
self
.
_pipe
is
not
None
...
@@ -118,7 +127,7 @@ class Experiment:
...
@@ -118,7 +127,7 @@ class Experiment:
# dispatcher must be launched after pipe initialized
# dispatcher must be launched after pipe initialized
# the logic to launch dispatcher in background should be refactored into dispatcher api
# the logic to launch dispatcher in background should be refactored into dispatcher api
self
.
_dispatcher
=
MsgDispatcher
(
self
.
tuner
,
None
)
self
.
_dispatcher
=
self
.
_create_dispatcher
(
)
self
.
_dispatcher_thread
=
Thread
(
target
=
self
.
_dispatcher
.
run
)
self
.
_dispatcher_thread
=
Thread
(
target
=
self
.
_dispatcher
.
run
)
self
.
_dispatcher_thread
.
start
()
self
.
_dispatcher_thread
.
start
()
...
@@ -128,32 +137,37 @@ class Experiment:
...
@@ -128,32 +137,37 @@ class Experiment:
if
interface
.
family
==
socket
.
AF_INET
:
if
interface
.
family
==
socket
.
AF_INET
:
ips
.
append
(
interface
.
address
)
ips
.
append
(
interface
.
address
)
ips
=
[
f
'http://
{
ip
}
:
{
port
}
'
for
ip
in
ips
if
ip
]
ips
=
[
f
'http://
{
ip
}
:
{
port
}
'
for
ip
in
ips
if
ip
]
msg
=
'Web UI URLs: '
+
colorama
.
Fore
.
CYAN
+
' '
.
join
(
ips
)
msg
=
'Web UI URLs: '
+
colorama
.
Fore
.
CYAN
+
' '
.
join
(
ips
)
+
colorama
.
Style
.
RESET_ALL
_logger
.
info
(
msg
)
_logger
.
info
(
msg
)
# TODO: register experiment management metadata
def
_create_dispatcher
(
self
):
# overrided by retiarii, temporary solution
return
MsgDispatcher
(
self
.
tuner
,
None
)
def
stop
(
self
)
->
None
:
def
stop
(
self
)
->
None
:
"""
"""
Stop background experiment.
Stop background experiment.
"""
"""
_logger
.
info
(
'Stopping experiment...'
)
_logger
.
info
(
'Stopping experiment
, please wait
...'
)
atexit
.
unregister
(
self
.
stop
)
atexit
.
unregister
(
self
.
stop
)
if
self
.
id
is
not
None
:
nni
.
runtime
.
log
.
stop_experiment_log
(
self
.
id
)
if
self
.
_proc
is
not
None
:
if
self
.
_proc
is
not
None
:
self
.
_proc
.
kill
(
)
kill_command
(
self
.
_proc
.
pid
)
if
self
.
_pipe
is
not
None
:
if
self
.
_pipe
is
not
None
:
self
.
_pipe
.
close
()
self
.
_pipe
.
close
()
if
self
.
_dispatcher_thread
is
not
None
:
if
self
.
_dispatcher_thread
is
not
None
:
self
.
_dispatcher
.
stopping
=
True
self
.
_dispatcher
.
stopping
=
True
self
.
_dispatcher_thread
.
join
(
timeout
=
1
)
self
.
_dispatcher_thread
.
join
(
timeout
=
1
)
self
.
id
=
None
self
.
port
=
None
self
.
port
=
None
self
.
_proc
=
None
self
.
_proc
=
None
self
.
_pipe
=
None
self
.
_pipe
=
None
self
.
_dispatcher
=
None
self
.
_dispatcher
=
None
self
.
_dispatcher_thread
=
None
self
.
_dispatcher_thread
=
None
_logger
.
info
(
'Experiment stopped'
)
def
run
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
bool
:
def
run
(
self
,
port
:
int
=
8080
,
debug
:
bool
=
False
)
->
bool
:
...
@@ -169,10 +183,12 @@ class Experiment:
...
@@ -169,10 +183,12 @@ class Experiment:
while
True
:
while
True
:
time
.
sleep
(
10
)
time
.
sleep
(
10
)
status
=
self
.
get_status
()
status
=
self
.
get_status
()
if
status
==
'STOPPED'
:
if
status
==
'DONE'
or
status
==
'STOPPED'
:
return
True
return
True
if
status
==
'ERROR'
:
if
status
==
'ERROR'
:
return
False
return
False
except
KeyboardInterrupt
:
_logger
.
warning
(
'KeyboardInterrupt detected'
)
finally
:
finally
:
self
.
stop
()
self
.
stop
()
...
...
nni/experiment/launcher.py
View file @
4784cc6c
...
@@ -14,33 +14,37 @@ import nni_node
...
@@ -14,33 +14,37 @@ import nni_node
from
.config
import
ExperimentConfig
from
.config
import
ExperimentConfig
from
.config
import
convert
from
.config
import
convert
from
.
import
management
from
.pipe
import
Pipe
from
.pipe
import
Pipe
from
.
import
rest
from
.
import
rest
from
..tools.nnictl.config_utils
import
Experiments
_logger
=
logging
.
getLogger
(
'nni.experiment'
)
_logger
=
logging
.
getLogger
(
'nni.experiment'
)
def
start_experiment
(
config
:
ExperimentConfig
,
port
:
int
,
debug
:
bool
)
->
Tuple
[
Popen
,
Pipe
]:
def
start_experiment
(
exp_id
:
str
,
config
:
ExperimentConfig
,
port
:
int
,
debug
:
bool
)
->
Tuple
[
Popen
,
Pipe
]:
pipe
=
None
pipe
=
None
proc
=
None
proc
=
None
config
.
validate
(
initialized_tuner
=
True
)
config
.
validate
(
initialized_tuner
=
True
)
_ensure_port_idle
(
port
)
_ensure_port_idle
(
port
)
if
config
.
training_service
.
platform
==
'openpai'
:
if
isinstance
(
config
.
training_service
,
list
):
# hybrid training service
_ensure_port_idle
(
port
+
1
,
'OpenPAI requires an additional port'
)
_ensure_port_idle
(
port
+
1
,
'Hybrid training service requires an additional port'
)
exp_id
=
management
.
generate_experiment_id
()
elif
config
.
training_service
.
platform
in
[
'remote'
,
'openpai'
,
'kubeflow'
,
'frameworkcontroller'
,
'adl'
]:
_ensure_port_idle
(
port
+
1
,
f
'
{
config
.
training_service
.
platform
}
requires an additional port'
)
try
:
try
:
_logger
.
info
(
'Creating experiment
%s
%s'
,
colorama
.
Fore
.
CYAN
,
exp_id
)
_logger
.
info
(
'Creating experiment
, Experiment ID:
%s'
,
colorama
.
Fore
.
CYAN
+
exp_id
+
colorama
.
Style
.
RESET_ALL
)
pipe
=
Pipe
(
exp_id
)
pipe
=
Pipe
(
exp_id
)
proc
=
_start_rest_server
(
config
,
port
,
debug
,
exp_id
,
pipe
.
path
)
start_time
,
proc
=
_start_rest_server
(
config
,
port
,
debug
,
exp_id
,
pipe
.
path
)
_logger
.
info
(
'Connecting IPC pipe...'
)
_logger
.
info
(
'Connecting IPC pipe...'
)
pipe_file
=
pipe
.
connect
()
pipe_file
=
pipe
.
connect
()
nni
.
runtime
.
protocol
.
_in_file
=
pipe_file
nni
.
runtime
.
protocol
.
_in_file
=
pipe_file
nni
.
runtime
.
protocol
.
_out_file
=
pipe_file
nni
.
runtime
.
protocol
.
_out_file
=
pipe_file
_logger
.
info
(
'Statring web server...'
)
_logger
.
info
(
'Statring web server...'
)
_check_rest_server
(
port
)
_check_rest_server
(
port
)
platform
=
'hybrid'
if
isinstance
(
config
.
training_service
,
list
)
else
config
.
training_service
.
platform
_save_experiment_information
(
exp_id
,
port
,
start_time
,
platform
,
config
.
experiment_name
,
proc
.
pid
,
config
.
experiment_working_directory
)
_logger
.
info
(
'Setting up...'
)
_logger
.
info
(
'Setting up...'
)
_init_experiment
(
config
,
port
,
debug
)
_init_experiment
(
config
,
port
,
debug
)
return
proc
,
pipe
return
proc
,
pipe
...
@@ -64,10 +68,13 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
...
@@ -64,10 +68,13 @@ def _ensure_port_idle(port: int, message: Optional[str] = None) -> None:
raise
RuntimeError
(
f
'Port
{
port
}
is not idle
{
message
}
'
)
raise
RuntimeError
(
f
'Port
{
port
}
is not idle
{
message
}
'
)
def
_start_rest_server
(
config
:
ExperimentConfig
,
port
:
int
,
debug
:
bool
,
experiment_id
:
str
,
pipe_path
:
str
)
->
Popen
:
def
_start_rest_server
(
config
:
ExperimentConfig
,
port
:
int
,
debug
:
bool
,
experiment_id
:
str
,
pipe_path
:
str
)
->
Tuple
[
int
,
Popen
]:
ts
=
config
.
training_service
.
platform
if
isinstance
(
config
.
training_service
,
list
):
if
ts
==
'openpai'
:
ts
=
'hybrid'
ts
=
'pai'
else
:
ts
=
config
.
training_service
.
platform
if
ts
==
'openpai'
:
ts
=
'pai'
args
=
{
args
=
{
'port'
:
port
,
'port'
:
port
,
...
@@ -85,7 +92,13 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim
...
@@ -85,7 +92,13 @@ def _start_rest_server(config: ExperimentConfig, port: int, debug: bool, experim
for
arg_key
,
arg_value
in
args
.
items
():
for
arg_key
,
arg_value
in
args
.
items
():
cmd
.
append
(
'--'
+
arg_key
)
cmd
.
append
(
'--'
+
arg_key
)
cmd
.
append
(
str
(
arg_value
))
cmd
.
append
(
str
(
arg_value
))
return
Popen
(
cmd
,
cwd
=
node_dir
)
if
sys
.
platform
==
'win32'
:
from
subprocess
import
CREATE_NEW_PROCESS_GROUP
proc
=
Popen
(
cmd
,
cwd
=
node_dir
,
creationflags
=
CREATE_NEW_PROCESS_GROUP
)
else
:
proc
=
Popen
(
cmd
,
cwd
=
node_dir
)
return
int
(
time
.
time
()
*
1000
),
proc
def
_check_rest_server
(
port
:
int
,
retry
:
int
=
3
)
->
None
:
def
_check_rest_server
(
port
:
int
,
retry
:
int
=
3
)
->
None
:
...
@@ -103,3 +116,8 @@ def _init_experiment(config: ExperimentConfig, port: int, debug: bool) -> None:
...
@@ -103,3 +116,8 @@ def _init_experiment(config: ExperimentConfig, port: int, debug: bool) -> None:
for
cluster_metadata
in
convert
.
to_cluster_metadata
(
config
):
for
cluster_metadata
in
convert
.
to_cluster_metadata
(
config
):
rest
.
put
(
port
,
'/experiment/cluster-metadata'
,
cluster_metadata
)
rest
.
put
(
port
,
'/experiment/cluster-metadata'
,
cluster_metadata
)
rest
.
post
(
port
,
'/experiment'
,
convert
.
to_rest_json
(
config
))
rest
.
post
(
port
,
'/experiment'
,
convert
.
to_rest_json
(
config
))
def
_save_experiment_information
(
experiment_id
:
str
,
port
:
int
,
start_time
:
int
,
platform
:
str
,
name
:
str
,
pid
:
int
,
logDir
:
str
)
->
None
:
experiment_config
=
Experiments
()
experiment_config
.
add_experiment
(
experiment_id
,
port
,
start_time
,
platform
,
name
,
pid
=
pid
,
logDir
=
logDir
)
nni/experiment/pipe.py
View file @
4784cc6c
...
@@ -31,7 +31,6 @@ if sys.platform == 'win32':
...
@@ -31,7 +31,6 @@ if sys.platform == 'win32':
def
close
(
self
)
->
None
:
def
close
(
self
)
->
None
:
if
self
.
file
is
not
None
:
if
self
.
file
is
not
None
:
self
.
file
.
close
()
self
.
file
.
close
()
_winapi
.
CloseHandle
(
self
.
_handle
)
Pipe
=
WindowsPipe
Pipe
=
WindowsPipe
...
...
nni/nas/pytorch/base_mutator.py
View file @
4784cc6c
...
@@ -110,7 +110,7 @@ class BaseMutator(nn.Module):
...
@@ -110,7 +110,7 @@ class BaseMutator(nn.Module):
Parameters
Parameters
----------
----------
mutable : LayerChoice
mutable :
nni.nas.pytorch.mutables.
LayerChoice
Module whose forward is called.
Module whose forward is called.
args : list of torch.Tensor
args : list of torch.Tensor
The arguments of its forward function.
The arguments of its forward function.
...
@@ -130,7 +130,7 @@ class BaseMutator(nn.Module):
...
@@ -130,7 +130,7 @@ class BaseMutator(nn.Module):
Parameters
Parameters
----------
----------
mutable : InputChoice
mutable :
nni.nas.pytorch.mutables.
InputChoice
Mutable that is called.
Mutable that is called.
tensor_list : list of torch.Tensor
tensor_list : list of torch.Tensor
The arguments mutable is called with.
The arguments mutable is called with.
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
15
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment