Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
b40e3db7
"src/include/blockwise_gemm.hip.hpp" did not exist on "61ac08661d47ecf84f4e0afc59d3261c035d4226"
Commit
b40e3db7
authored
Dec 01, 2020
by
quzha
Browse files
Merge branch 'master' of github.com:Microsoft/nni into dev-retiarii
parents
efa4e31c
95f731e4
Changes
226
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1114 additions
and
309 deletions
+1114
-309
nni/algorithms/hpo/hyperband_advisor/hyperband_advisor.py
nni/algorithms/hpo/hyperband_advisor/hyperband_advisor.py
+53
-23
nni/algorithms/nas/pytorch/cdarts/utils.py
nni/algorithms/nas/pytorch/cdarts/utils.py
+1
-1
nni/algorithms/nas/pytorch/cream/__init__.py
nni/algorithms/nas/pytorch/cream/__init__.py
+4
-0
nni/algorithms/nas/pytorch/cream/trainer.py
nni/algorithms/nas/pytorch/cream/trainer.py
+406
-0
nni/algorithms/nas/pytorch/cream/utils.py
nni/algorithms/nas/pytorch/cream/utils.py
+37
-0
nni/algorithms/nas/pytorch/darts/trainer.py
nni/algorithms/nas/pytorch/darts/trainer.py
+1
-1
nni/common/graph_utils.py
nni/common/graph_utils.py
+72
-39
nni/compression/pytorch/utils/counter.py
nni/compression/pytorch/utils/counter.py
+262
-101
nni/compression/pytorch/utils/mask_conflict.py
nni/compression/pytorch/utils/mask_conflict.py
+5
-3
nni/experiment/nni_client.py
nni/experiment/nni_client.py
+2
-8
nni/runtime/msg_dispatcher.py
nni/runtime/msg_dispatcher.py
+5
-2
nni/runtime/platform/__init__.py
nni/runtime/platform/__init__.py
+1
-1
nni/tools/nnictl/common_utils.py
nni/tools/nnictl/common_utils.py
+36
-0
nni/tools/nnictl/config_schema.py
nni/tools/nnictl/config_schema.py
+26
-1
nni/tools/nnictl/config_utils.py
nni/tools/nnictl/config_utils.py
+44
-28
nni/tools/nnictl/constants.py
nni/tools/nnictl/constants.py
+7
-7
nni/tools/nnictl/launcher.py
nni/tools/nnictl/launcher.py
+46
-40
nni/tools/nnictl/launcher_utils.py
nni/tools/nnictl/launcher_utils.py
+4
-0
nni/tools/nnictl/nnictl_utils.py
nni/tools/nnictl/nnictl_utils.py
+73
-46
nni/tools/nnictl/tensorboard_utils.py
nni/tools/nnictl/tensorboard_utils.py
+29
-8
No files found.
nni/algorithms/hpo/hyperband_advisor/hyperband_advisor.py
View file @
b40e3db7
...
...
@@ -46,7 +46,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
Parameters
----------
brackets_id: in
t
brackets_id:
str
in
g
brackets id
brackets_curr_decay:
brackets curr decay
...
...
@@ -60,7 +60,7 @@ def create_bracket_parameter_id(brackets_id, brackets_curr_decay, increased_id=-
"""
if
increased_id
==
-
1
:
increased_id
=
str
(
create_parameter_id
())
params_id
=
'_'
.
join
([
str
(
brackets_id
)
,
params_id
=
'_'
.
join
([
brackets_id
,
str
(
brackets_curr_decay
),
increased_id
])
return
params_id
...
...
@@ -108,6 +108,8 @@ class Bracket():
Parameters
----------
bracket_id: string
The id of this bracket, usually be set as '{Hyperband index}-{SH iteration index}'
s: int
The current SH iteration index.
s_max: int
...
...
@@ -122,8 +124,9 @@ class Bracket():
optimize mode, 'maximize' or 'minimize'
"""
def
__init__
(
self
,
s
,
s_max
,
eta
,
R
,
optimize_mode
):
self
.
bracket_id
=
s
def
__init__
(
self
,
bracket_id
,
s
,
s_max
,
eta
,
R
,
optimize_mode
):
self
.
bracket_id
=
bracket_id
self
.
s
=
s
self
.
s_max
=
s_max
self
.
eta
=
eta
self
.
n
=
math
.
ceil
((
s_max
+
1
)
*
(
eta
**
s
)
/
(
s
+
1
)
-
_epsilon
)
...
...
@@ -147,7 +150,7 @@ class Bracket():
def
increase_i
(
self
):
"""i means the ith round. Increase i by 1"""
self
.
i
+=
1
if
self
.
i
>
self
.
bracket_id
:
if
self
.
i
>
self
.
s
:
self
.
no_more_trial
=
True
def
set_config_perf
(
self
,
i
,
parameter_id
,
seq
,
value
):
...
...
@@ -256,13 +259,14 @@ class HyperbandClassArgsValidator(ClassArgsValidator):
def
validate_class_args
(
self
,
**
kwargs
):
Schema
({
'optimize_mode'
:
self
.
choices
(
'optimize_mode'
,
'maximize'
,
'minimize'
),
Optional
(
'exec_mode'
):
self
.
choices
(
'exec_mode'
,
'serial'
,
'parallelism'
),
Optional
(
'R'
):
int
,
Optional
(
'eta'
):
int
}).
validate
(
kwargs
)
class
Hyperband
(
MsgDispatcherBase
):
"""Hyperband inherit from MsgDispatcherBase rather than Tuner, because it integrates both tuner's functions and assessor's functions.
This is an implementation that could fully leverage available resources, i.e., high parallelism.
This is an implementation that could fully leverage available resources
or follow the algorithm process
, i.e., high parallelism
or serial
.
A single execution of Hyperband takes a finite budget of (s_max + 1)B.
Parameters
...
...
@@ -273,9 +277,11 @@ class Hyperband(MsgDispatcherBase):
the variable that controls the proportion of configurations discarded in each round of SuccessiveHalving
optimize_mode: str
optimize mode, 'maximize' or 'minimize'
exec_mode: str
execution mode, 'serial' or 'parallelism'
"""
def
__init__
(
self
,
R
=
60
,
eta
=
3
,
optimize_mode
=
'maximize'
):
def
__init__
(
self
,
R
=
60
,
eta
=
3
,
optimize_mode
=
'maximize'
,
exec_mode
=
'parallelism'
):
"""B = (s_max + 1)R"""
super
(
Hyperband
,
self
).
__init__
()
self
.
R
=
R
...
...
@@ -285,6 +291,9 @@ class Hyperband(MsgDispatcherBase):
self
.
completed_hyper_configs
=
[]
# all the completed configs
self
.
s_max
=
math
.
floor
(
math
.
log
(
self
.
R
,
self
.
eta
)
+
_epsilon
)
self
.
curr_s
=
self
.
s_max
self
.
curr_hb
=
0
self
.
exec_mode
=
exec_mode
self
.
curr_bracket_id
=
None
self
.
searchspace_json
=
None
self
.
random_state
=
None
...
...
@@ -316,25 +325,44 @@ class Hyperband(MsgDispatcherBase):
data: int
number of trial jobs
"""
for
_
in
range
(
data
):
ret
=
self
.
_get_one_trial_job
()
self
.
credit
+=
data
for
_
in
range
(
self
.
credit
):
self
.
_request_one_trial_job
()
def
_request_one_trial_job
(
self
):
ret
=
self
.
_get_one_trial_job
()
if
ret
is
not
None
:
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
ret
))
self
.
credit
-=
1
def
_get_one_trial_job
(
self
):
"""get one trial job, i.e., one hyperparameter configuration."""
if
not
self
.
generated_hyper_configs
:
if
self
.
curr_s
<
0
:
self
.
curr_s
=
self
.
s_max
_logger
.
debug
(
'create a new bracket, self.curr_s=%d'
,
self
.
curr_s
)
self
.
brackets
[
self
.
curr_s
]
=
Bracket
(
self
.
curr_s
,
self
.
s_max
,
self
.
eta
,
self
.
R
,
self
.
optimize_mode
)
next_n
,
next_r
=
self
.
brackets
[
self
.
curr_s
].
get_n_r
()
_logger
.
debug
(
'new bracket, next_n=%d, next_r=%d'
,
next_n
,
next_r
)
assert
self
.
searchspace_json
is
not
None
and
self
.
random_state
is
not
None
generated_hyper_configs
=
self
.
brackets
[
self
.
curr_s
].
get_hyperparameter_configurations
(
next_n
,
next_r
,
self
.
searchspace_json
,
self
.
random_state
)
self
.
generated_hyper_configs
=
generated_hyper_configs
.
copy
()
self
.
curr_s
-=
1
if
self
.
exec_mode
==
'parallelism'
or
\
(
self
.
exec_mode
==
'serial'
and
(
self
.
curr_bracket_id
is
None
or
self
.
brackets
[
self
.
curr_bracket_id
].
is_completed
())):
if
self
.
curr_s
<
0
:
self
.
curr_s
=
self
.
s_max
self
.
curr_hb
+=
1
_logger
.
debug
(
'create a new bracket, self.curr_hb=%d, self.curr_s=%d'
,
self
.
curr_hb
,
self
.
curr_s
)
self
.
curr_bracket_id
=
'{}-{}'
.
format
(
self
.
curr_hb
,
self
.
curr_s
)
self
.
brackets
[
self
.
curr_bracket_id
]
=
Bracket
(
self
.
curr_bracket_id
,
self
.
curr_s
,
self
.
s_max
,
self
.
eta
,
self
.
R
,
self
.
optimize_mode
)
next_n
,
next_r
=
self
.
brackets
[
self
.
curr_bracket_id
].
get_n_r
()
_logger
.
debug
(
'new bracket, next_n=%d, next_r=%d'
,
next_n
,
next_r
)
assert
self
.
searchspace_json
is
not
None
and
self
.
random_state
is
not
None
generated_hyper_configs
=
self
.
brackets
[
self
.
curr_bracket_id
].
get_hyperparameter_configurations
(
next_n
,
next_r
,
self
.
searchspace_json
,
self
.
random_state
)
self
.
generated_hyper_configs
=
generated_hyper_configs
.
copy
()
self
.
curr_s
-=
1
else
:
ret
=
{
'parameter_id'
:
'-1_0_0'
,
'parameter_source'
:
'algorithm'
,
'parameters'
:
''
}
send
(
CommandType
.
NoMoreTrialJobs
,
json_tricks
.
dumps
(
ret
))
return
None
assert
self
.
generated_hyper_configs
params
=
self
.
generated_hyper_configs
.
pop
(
0
)
...
...
@@ -358,10 +386,12 @@ class Hyperband(MsgDispatcherBase):
parameter_id: parameter id of the finished config
"""
bracket_id
,
i
,
_
=
parameter_id
.
split
(
'_'
)
hyper_configs
=
self
.
brackets
[
int
(
bracket_id
)
].
inform_trial_end
(
int
(
i
))
hyper_configs
=
self
.
brackets
[
bracket_id
].
inform_trial_end
(
int
(
i
))
if
hyper_configs
is
not
None
:
_logger
.
debug
(
'bracket %s next round %s, hyper_configs: %s'
,
bracket_id
,
i
,
hyper_configs
)
self
.
generated_hyper_configs
=
self
.
generated_hyper_configs
+
hyper_configs
for
_
in
range
(
self
.
credit
):
self
.
_request_one_trial_job
()
def
handle_trial_end
(
self
,
data
):
"""
...
...
@@ -392,6 +422,7 @@ class Hyperband(MsgDispatcherBase):
"""
if
'value'
in
data
:
data
[
'value'
]
=
json_tricks
.
loads
(
data
[
'value'
])
# multiphase? need to check
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
assert
multi_phase_enabled
()
assert
data
[
'trial_job_id'
]
is
not
None
...
...
@@ -408,7 +439,6 @@ class Hyperband(MsgDispatcherBase):
else
:
value
=
extract_scalar_reward
(
data
[
'value'
])
bracket_id
,
i
,
_
=
data
[
'parameter_id'
].
split
(
'_'
)
bracket_id
=
int
(
bracket_id
)
# add <trial_job_id, parameter_id> to self.job_id_para_id_map here,
# because when the first parameter_id is created, trial_job_id is not known yet.
...
...
nni/algorithms/nas/pytorch/cdarts/utils.py
View file @
b40e3db7
...
...
@@ -58,7 +58,7 @@ def accuracy(output, target, topk=(1,)):
res
=
[]
for
k
in
topk
:
correct_k
=
correct
[:
k
].
view
(
-
1
).
float
().
sum
(
0
)
correct_k
=
correct
[:
k
].
reshape
(
-
1
).
float
().
sum
(
0
)
res
.
append
(
correct_k
.
mul_
(
1.0
/
batch_size
))
return
res
...
...
nni/algorithms/nas/pytorch/cream/__init__.py
0 → 100755
View file @
b40e3db7
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.trainer
import
CreamSupernetTrainer
nni/algorithms/nas/pytorch/cream/trainer.py
0 → 100644
View file @
b40e3db7
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
os
import
torch
import
logging
from
copy
import
deepcopy
from
nni.nas.pytorch.trainer
import
Trainer
from
nni.nas.pytorch.utils
import
AverageMeterGroup
from
.utils
import
accuracy
,
reduce_metrics
logger
=
logging
.
getLogger
(
__name__
)
class
CreamSupernetTrainer
(
Trainer
):
"""
This trainer trains a supernet and output prioritized architectures that can be used for other tasks.
Parameters
----------
model : nn.Module
Model with mutables.
loss : callable
Called with logits and targets. Returns a loss tensor.
val_loss : callable
Called with logits and targets for validation only. Returns a loss tensor.
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
train_loader : iterablez
Data loader of training. Raise ``StopIteration`` when one epoch is exhausted.
valid_loader : iterablez
Data loader of validation. Raise ``StopIteration`` when one epoch is exhausted.
mutator : Mutator
A mutator object that has been initialized with the model.
batch_size : int
Batch size.
log_frequency : int
Number of mini-batches to log metrics.
meta_sta_epoch : int
start epoch of using meta matching network to pick teacher architecture
update_iter : int
interval of updating meta matching networks
slices : int
batch size of mini training data in the process of training meta matching network
pool_size : int
board size
pick_method : basestring
how to pick teacher network
choice_num : int
number of operations in supernet
sta_num : int
layer number of each stage in supernet (5 stage in supernet)
acc_gap : int
maximum accuracy improvement to omit the limitation of flops
flops_dict : Dict
dictionary of each layer's operations in supernet
flops_fixed : int
flops of fixed part in supernet
local_rank : int
index of current rank
callbacks : list of Callback
Callbacks to plug into the trainer. See Callbacks.
"""
def
__init__
(
self
,
model
,
loss
,
val_loss
,
optimizer
,
num_epochs
,
train_loader
,
valid_loader
,
mutator
=
None
,
batch_size
=
64
,
log_frequency
=
None
,
meta_sta_epoch
=
20
,
update_iter
=
200
,
slices
=
2
,
pool_size
=
10
,
pick_method
=
'meta'
,
choice_num
=
6
,
sta_num
=
(
4
,
4
,
4
,
4
,
4
),
acc_gap
=
5
,
flops_dict
=
None
,
flops_fixed
=
0
,
local_rank
=
0
,
callbacks
=
None
):
assert
torch
.
cuda
.
is_available
()
super
(
CreamSupernetTrainer
,
self
).
__init__
(
model
,
mutator
,
loss
,
None
,
optimizer
,
num_epochs
,
None
,
None
,
batch_size
,
None
,
None
,
log_frequency
,
callbacks
)
self
.
model
=
model
self
.
loss
=
loss
self
.
val_loss
=
val_loss
self
.
train_loader
=
train_loader
self
.
valid_loader
=
valid_loader
self
.
log_frequency
=
log_frequency
self
.
batch_size
=
batch_size
self
.
optimizer
=
optimizer
self
.
model
=
model
self
.
loss
=
loss
self
.
num_epochs
=
num_epochs
self
.
meta_sta_epoch
=
meta_sta_epoch
self
.
update_iter
=
update_iter
self
.
slices
=
slices
self
.
pick_method
=
pick_method
self
.
pool_size
=
pool_size
self
.
local_rank
=
local_rank
self
.
choice_num
=
choice_num
self
.
sta_num
=
sta_num
self
.
acc_gap
=
acc_gap
self
.
flops_dict
=
flops_dict
self
.
flops_fixed
=
flops_fixed
self
.
current_student_arch
=
None
self
.
current_teacher_arch
=
None
self
.
main_proc
=
(
local_rank
==
0
)
self
.
current_epoch
=
0
self
.
prioritized_board
=
[]
# size of prioritized board
def
_board_size
(
self
):
return
len
(
self
.
prioritized_board
)
# select teacher architecture according to the logit difference
def
_select_teacher
(
self
):
self
.
_replace_mutator_cand
(
self
.
current_student_arch
)
if
self
.
pick_method
==
'top1'
:
meta_value
,
teacher_cand
=
0.5
,
sorted
(
self
.
prioritized_board
,
reverse
=
True
)[
0
][
3
]
elif
self
.
pick_method
==
'meta'
:
meta_value
,
cand_idx
,
teacher_cand
=
-
1000000000
,
-
1
,
None
for
now_idx
,
item
in
enumerate
(
self
.
prioritized_board
):
inputx
=
item
[
4
]
output
=
torch
.
nn
.
functional
.
softmax
(
self
.
model
(
inputx
),
dim
=
1
)
weight
=
self
.
model
.
module
.
forward_meta
(
output
-
item
[
5
])
if
weight
>
meta_value
:
meta_value
=
weight
cand_idx
=
now_idx
teacher_cand
=
self
.
prioritized_board
[
cand_idx
][
3
]
assert
teacher_cand
is
not
None
meta_value
=
torch
.
nn
.
functional
.
sigmoid
(
-
weight
)
else
:
raise
ValueError
(
'Method Not supported'
)
return
meta_value
,
teacher_cand
# check whether to update prioritized board
def
_isUpdateBoard
(
self
,
prec1
,
flops
):
if
self
.
current_epoch
<=
self
.
meta_sta_epoch
:
return
False
if
len
(
self
.
prioritized_board
)
<
self
.
pool_size
:
return
True
if
prec1
>
self
.
prioritized_board
[
-
1
][
1
]
+
self
.
acc_gap
:
return
True
if
prec1
>
self
.
prioritized_board
[
-
1
][
1
]
and
flops
<
self
.
prioritized_board
[
-
1
][
2
]:
return
True
return
False
# update prioritized board
def
_update_prioritized_board
(
self
,
inputs
,
teacher_output
,
outputs
,
prec1
,
flops
):
if
self
.
_isUpdateBoard
(
prec1
,
flops
):
val_prec1
=
prec1
training_data
=
deepcopy
(
inputs
[:
self
.
slices
].
detach
())
if
len
(
self
.
prioritized_board
)
==
0
:
features
=
deepcopy
(
outputs
[:
self
.
slices
].
detach
())
else
:
features
=
deepcopy
(
teacher_output
[:
self
.
slices
].
detach
())
self
.
prioritized_board
.
append
(
(
val_prec1
,
prec1
,
flops
,
self
.
current_teacher_arch
,
training_data
,
torch
.
nn
.
functional
.
softmax
(
features
,
dim
=
1
)))
self
.
prioritized_board
=
sorted
(
self
.
prioritized_board
,
reverse
=
True
)
if
len
(
self
.
prioritized_board
)
>
self
.
pool_size
:
self
.
prioritized_board
=
sorted
(
self
.
prioritized_board
,
reverse
=
True
)
del
self
.
prioritized_board
[
-
1
]
# only update student network weights
def
_update_student_weights_only
(
self
,
grad_1
):
for
weight
,
grad_item
in
zip
(
self
.
model
.
module
.
rand_parameters
(
self
.
current_student_arch
),
grad_1
):
weight
.
grad
=
grad_item
torch
.
nn
.
utils
.
clip_grad_norm_
(
self
.
model
.
module
.
rand_parameters
(
self
.
current_student_arch
),
1
)
self
.
optimizer
.
step
()
for
weight
,
grad_item
in
zip
(
self
.
model
.
module
.
rand_parameters
(
self
.
current_student_arch
),
grad_1
):
del
weight
.
grad
# only update meta networks weights
def
_update_meta_weights_only
(
self
,
teacher_cand
,
grad_teacher
):
for
weight
,
grad_item
in
zip
(
self
.
model
.
module
.
rand_parameters
(
teacher_cand
,
self
.
pick_method
==
'meta'
),
grad_teacher
):
weight
.
grad
=
grad_item
# clip gradients
torch
.
nn
.
utils
.
clip_grad_norm_
(
self
.
model
.
module
.
rand_parameters
(
self
.
current_student_arch
,
self
.
pick_method
==
'meta'
),
1
)
self
.
optimizer
.
step
()
for
weight
,
grad_item
in
zip
(
self
.
model
.
module
.
rand_parameters
(
teacher_cand
,
self
.
pick_method
==
'meta'
),
grad_teacher
):
del
weight
.
grad
# simulate sgd updating
def
_simulate_sgd_update
(
self
,
w
,
g
,
optimizer
):
return
g
*
optimizer
.
param_groups
[
-
1
][
'lr'
]
+
w
# split training images into several slices
def
_get_minibatch_input
(
self
,
input
):
slice
=
self
.
slices
x
=
deepcopy
(
input
[:
slice
].
clone
().
detach
())
return
x
# calculate 1st gradient of student architectures
def
_calculate_1st_gradient
(
self
,
kd_loss
):
self
.
optimizer
.
zero_grad
()
grad
=
torch
.
autograd
.
grad
(
kd_loss
,
self
.
model
.
module
.
rand_parameters
(
self
.
current_student_arch
),
create_graph
=
True
)
return
grad
# calculate 2nd gradient of meta networks
def
_calculate_2nd_gradient
(
self
,
validation_loss
,
teacher_cand
,
students_weight
):
self
.
optimizer
.
zero_grad
()
grad_student_val
=
torch
.
autograd
.
grad
(
validation_loss
,
self
.
model
.
module
.
rand_parameters
(
self
.
random_cand
),
retain_graph
=
True
)
grad_teacher
=
torch
.
autograd
.
grad
(
students_weight
[
0
],
self
.
model
.
module
.
rand_parameters
(
teacher_cand
,
self
.
pick_method
==
'meta'
),
grad_outputs
=
grad_student_val
)
return
grad_teacher
# forward training data
def
_forward_training
(
self
,
x
,
meta_value
):
self
.
_replace_mutator_cand
(
self
.
current_student_arch
)
output
=
self
.
model
(
x
)
with
torch
.
no_grad
():
self
.
_replace_mutator_cand
(
self
.
current_teacher_arch
)
teacher_output
=
self
.
model
(
x
)
soft_label
=
torch
.
nn
.
functional
.
softmax
(
teacher_output
,
dim
=
1
)
kd_loss
=
meta_value
*
\
self
.
_cross_entropy_loss_with_soft_target
(
output
,
soft_label
)
return
kd_loss
# calculate soft target loss
def
_cross_entropy_loss_with_soft_target
(
self
,
pred
,
soft_target
):
logsoftmax
=
torch
.
nn
.
LogSoftmax
()
return
torch
.
mean
(
torch
.
sum
(
-
soft_target
*
logsoftmax
(
pred
),
1
))
# forward validation data
def
_forward_validation
(
self
,
input
,
target
):
slice
=
self
.
slices
x
=
input
[
slice
:
slice
*
2
].
clone
()
self
.
_replace_mutator_cand
(
self
.
current_student_arch
)
output_2
=
self
.
model
(
x
)
validation_loss
=
self
.
loss
(
output_2
,
target
[
slice
:
slice
*
2
])
return
validation_loss
def
_isUpdateMeta
(
self
,
batch_idx
):
isUpdate
=
True
isUpdate
&=
(
self
.
current_epoch
>
self
.
meta_sta_epoch
)
isUpdate
&=
(
batch_idx
>
0
)
isUpdate
&=
(
batch_idx
%
self
.
update_iter
==
0
)
isUpdate
&=
(
self
.
_board_size
()
>
0
)
return
isUpdate
def
_replace_mutator_cand
(
self
,
cand
):
self
.
mutator
.
_cache
=
cand
# update meta matching networks
def
_run_update
(
self
,
input
,
target
,
batch_idx
):
if
self
.
_isUpdateMeta
(
batch_idx
):
x
=
self
.
_get_minibatch_input
(
input
)
meta_value
,
teacher_cand
=
self
.
_select_teacher
()
kd_loss
=
self
.
_forward_training
(
x
,
meta_value
)
# calculate 1st gradient
grad_1st
=
self
.
_calculate_1st_gradient
(
kd_loss
)
# simulate updated student weights
students_weight
=
[
self
.
_simulate_sgd_update
(
p
,
grad_item
,
self
.
optimizer
)
for
p
,
grad_item
in
zip
(
self
.
model
.
module
.
rand_parameters
(
self
.
current_student_arch
),
grad_1st
)]
# update student weights
self
.
_update_student_weights_only
(
grad_1st
)
validation_loss
=
self
.
_forward_validation
(
input
,
target
)
# calculate 2nd gradient
grad_teacher
=
self
.
_calculate_2nd_gradient
(
validation_loss
,
teacher_cand
,
students_weight
)
# update meta matching networks
self
.
_update_meta_weights_only
(
teacher_cand
,
grad_teacher
)
# delete internal variants
del
grad_teacher
,
grad_1st
,
x
,
validation_loss
,
kd_loss
,
students_weight
def
_get_cand_flops
(
self
,
cand
):
flops
=
0
for
block_id
,
block
in
enumerate
(
cand
):
if
block
==
'LayerChoice1'
or
block_id
==
'LayerChoice23'
:
continue
for
idx
,
choice
in
enumerate
(
cand
[
block
]):
flops
+=
self
.
flops_dict
[
block_id
][
idx
]
*
(
1
if
choice
else
0
)
return
flops
+
self
.
flops_fixed
def
train_one_epoch
(
self
,
epoch
):
self
.
current_epoch
=
epoch
meters
=
AverageMeterGroup
()
self
.
steps_per_epoch
=
len
(
self
.
train_loader
)
for
step
,
(
input_data
,
target
)
in
enumerate
(
self
.
train_loader
):
self
.
mutator
.
reset
()
self
.
current_student_arch
=
self
.
mutator
.
_cache
input_data
,
target
=
input_data
.
cuda
(),
target
.
cuda
()
# calculate flops of current architecture
cand_flops
=
self
.
_get_cand_flops
(
self
.
mutator
.
_cache
)
# update meta matching network
self
.
_run_update
(
input_data
,
target
,
step
)
if
self
.
_board_size
()
>
0
:
# select teacher architecture
meta_value
,
teacher_cand
=
self
.
_select_teacher
()
self
.
current_teacher_arch
=
teacher_cand
# forward supernet
if
self
.
_board_size
()
==
0
or
epoch
<=
self
.
meta_sta_epoch
:
self
.
_replace_mutator_cand
(
self
.
current_student_arch
)
output
=
self
.
model
(
input_data
)
loss
=
self
.
loss
(
output
,
target
)
kd_loss
,
teacher_output
,
teacher_cand
=
None
,
None
,
None
else
:
self
.
_replace_mutator_cand
(
self
.
current_student_arch
)
output
=
self
.
model
(
input_data
)
gt_loss
=
self
.
loss
(
output
,
target
)
with
torch
.
no_grad
():
self
.
_replace_mutator_cand
(
self
.
current_teacher_arch
)
teacher_output
=
self
.
model
(
input_data
).
detach
()
soft_label
=
torch
.
nn
.
functional
.
softmax
(
teacher_output
,
dim
=
1
)
kd_loss
=
self
.
_cross_entropy_loss_with_soft_target
(
output
,
soft_label
)
loss
=
(
meta_value
*
kd_loss
+
(
2
-
meta_value
)
*
gt_loss
)
/
2
# update network
self
.
optimizer
.
zero_grad
()
loss
.
backward
()
self
.
optimizer
.
step
()
# update metrics
prec1
,
prec5
=
accuracy
(
output
,
target
,
topk
=
(
1
,
5
))
metrics
=
{
"prec1"
:
prec1
,
"prec5"
:
prec5
,
"loss"
:
loss
}
metrics
=
reduce_metrics
(
metrics
)
meters
.
update
(
metrics
)
# update prioritized board
self
.
_update_prioritized_board
(
input_data
,
teacher_output
,
output
,
metrics
[
'prec1'
],
cand_flops
)
if
self
.
main_proc
and
(
step
%
self
.
log_frequency
==
0
or
step
+
1
==
self
.
steps_per_epoch
):
logger
.
info
(
"Epoch [%d/%d] Step [%d/%d] %s"
,
epoch
+
1
,
self
.
num_epochs
,
step
+
1
,
len
(
self
.
train_loader
),
meters
)
if
self
.
main_proc
and
self
.
num_epochs
==
epoch
+
1
:
for
idx
,
i
in
enumerate
(
self
.
best_children_pool
):
logger
.
info
(
"No.%s %s"
,
idx
,
i
[:
4
])
def
validate_one_epoch
(
self
,
epoch
):
self
.
model
.
eval
()
meters
=
AverageMeterGroup
()
with
torch
.
no_grad
():
for
step
,
(
x
,
y
)
in
enumerate
(
self
.
valid_loader
):
self
.
mutator
.
reset
()
logits
=
self
.
model
(
x
)
loss
=
self
.
val_loss
(
logits
,
y
)
prec1
,
prec5
=
self
.
accuracy
(
logits
,
y
,
topk
=
(
1
,
5
))
metrics
=
{
"prec1"
:
prec1
,
"prec5"
:
prec5
,
"loss"
:
loss
}
metrics
=
self
.
reduce_metrics
(
metrics
,
self
.
distributed
)
meters
.
update
(
metrics
)
if
self
.
log_frequency
is
not
None
and
step
%
self
.
log_frequency
==
0
:
logger
.
info
(
"Epoch [%s/%s] Validation Step [%s/%s] %s"
,
epoch
+
1
,
self
.
num_epochs
,
step
+
1
,
len
(
self
.
valid_loader
),
meters
)
nni/algorithms/nas/pytorch/cream/utils.py
0 → 100644
View file @
b40e3db7
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
os
import
torch.distributed
as
dist
def
accuracy
(
output
,
target
,
topk
=
(
1
,)):
""" Computes the precision@k for the specified values of k """
maxk
=
max
(
topk
)
batch_size
=
target
.
size
(
0
)
_
,
pred
=
output
.
topk
(
maxk
,
1
,
True
,
True
)
pred
=
pred
.
t
()
# one-hot case
if
target
.
ndimension
()
>
1
:
target
=
target
.
max
(
1
)[
1
]
correct
=
pred
.
eq
(
target
.
view
(
1
,
-
1
).
expand_as
(
pred
))
res
=
[]
for
k
in
topk
:
correct_k
=
correct
[:
k
].
view
(
-
1
).
float
().
sum
(
0
)
res
.
append
(
correct_k
.
mul_
(
1.0
/
batch_size
))
return
res
def
reduce_metrics
(
metrics
):
return
{
k
:
reduce_tensor
(
v
).
item
()
for
k
,
v
in
metrics
.
items
()}
def
reduce_tensor
(
tensor
):
rt
=
tensor
.
clone
()
dist
.
all_reduce
(
rt
,
op
=
dist
.
ReduceOp
.
SUM
)
rt
/=
float
(
os
.
environ
[
"WORLD_SIZE"
])
return
rt
nni/algorithms/nas/pytorch/darts/trainer.py
View file @
b40e3db7
...
...
@@ -210,5 +210,5 @@ class DartsTrainer(Trainer):
dalphas
.
append
(
torch
.
autograd
.
grad
(
loss
,
self
.
mutator
.
parameters
()))
dalpha_pos
,
dalpha_neg
=
dalphas
# dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
hessian
=
[(
p
-
n
)
/
2.
*
eps
for
p
,
n
in
zip
(
dalpha_pos
,
dalpha_neg
)]
hessian
=
[(
p
-
n
)
/
(
2.
*
eps
)
for
p
,
n
in
zip
(
dalpha_pos
,
dalpha_neg
)]
return
hessian
nni/common/graph_utils.py
View file @
b40e3db7
...
...
@@ -15,6 +15,7 @@ LIST_CONSTRUCT_KIND = 'prim::ListConstruct'
LIST_UNPACK_KIND
=
'prim::ListUnpack'
TUPLE_CONSTRUCT_KIND
=
'prim::TupleConstruct'
TUPLE_UNPACK_KIND
=
'prim::TupleUnpack'
CONSTANT_KIND
=
'prim::Constant'
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -68,9 +69,11 @@ class TorchGraph:
'Please provide model & dummy_input or the traced_model as inputs'
)
def
_trace
(
self
,
model
,
dummy_input
):
with
torch
.
onnx
.
set_training
(
model
,
False
):
self
.
trace
=
torch
.
jit
.
trace
(
model
,
dummy_input
)
torch
.
_C
.
_jit_pass_inline
(
self
.
trace
.
graph
)
training
=
model
.
training
model
.
eval
()
self
.
trace
=
torch
.
jit
.
trace
(
model
,
dummy_input
)
torch
.
_C
.
_jit_pass_inline
(
self
.
trace
.
graph
)
model
.
train
(
training
)
class
TorchProtoGraph
(
TorchGraph
):
...
...
@@ -282,27 +285,35 @@ class TorchModuleGraph(TorchGraph):
self
.
global_count
+=
1
op_type
=
node
.
kind
()
node_group
=
[
node
]
inputs
=
li
st
()
outputs
=
li
st
()
inputs
=
s
e
t
()
outputs
=
s
e
t
()
node_queue
=
queue
.
Queue
()
node_queue
.
put
(
node
)
while
not
node_queue
.
empty
():
curr_node
=
node_queue
.
get
()
for
_input
in
curr_node
.
inputs
():
if
_input
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
input_name
=
_input
.
debugName
()
if
input_name
in
output_to_node
and
output_to_node
[
input_name
]
in
nodes
:
predecessor_node
=
output_to_node
[
input_name
]
if
not
self
.
_is_key_func
(
predecessor_node
):
node_group
.
append
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
else
:
inputs
.
append
(
input_name
)
if
input_name
in
output_to_node
:
for
predecessor_node
in
output_to_node
[
input_name
]:
if
predecessor_node
in
nodes
:
if
not
self
.
_is_key_func
(
predecessor_node
):
if
predecessor_node
not
in
node_group
:
node_group
.
append
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
else
:
inputs
.
add
(
input_name
)
else
:
inputs
.
add
(
input_name
)
else
:
inputs
.
a
ppen
d
(
input_name
)
inputs
.
a
d
d
(
input_name
)
for
output
in
node
.
outputs
():
outputs
.
append
(
output
.
debugName
())
if
output
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
outputs
.
add
(
output
.
debugName
())
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
node_group
,
inputs
=
inputs
,
outputs
=
outputs
,
key_node
=
node
)
node_group
,
inputs
=
list
(
inputs
)
,
outputs
=
list
(
outputs
)
,
key_node
=
node
)
return
nodepy
def
_expand_module_node
(
self
,
node
,
node_name
,
unique_name
,
op_type
,
nodes
,
...
...
@@ -342,36 +353,46 @@ class TorchModuleGraph(TorchGraph):
if
not
op_type
:
op_type
=
node
.
kind
()
node_group
=
[
node
]
inputs
=
li
st
()
outputs
=
li
st
()
inputs
=
s
e
t
()
outputs
=
s
e
t
()
node_queue
=
queue
.
Queue
()
node_queue
.
put
(
node
)
visited
=
{
node
}
while
not
node_queue
.
empty
():
curr_node
=
node_queue
.
get
()
for
_input
in
curr_node
.
inputs
():
if
_input
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
input_name
=
_input
.
debugName
()
if
input_name
in
output_to_node
and
output_to_node
[
input_name
]
in
nodes
:
predecessor_node
=
output_to_node
[
input_name
]
if
predecessor_node
not
in
visited
:
node_group
.
append
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
visited
.
add
(
predecessor_node
)
if
input_name
in
output_to_node
:
for
predecessor_node
in
output_to_node
[
input_name
]:
if
predecessor_node
in
nodes
:
if
predecessor_node
not
in
visited
:
node_group
.
append
(
predecessor_node
)
node_queue
.
put
(
predecessor_node
)
visited
.
add
(
predecessor_node
)
else
:
inputs
.
add
(
input_name
)
else
:
inputs
.
a
ppen
d
(
input_name
)
inputs
.
a
d
d
(
input_name
)
for
_output
in
curr_node
.
outputs
():
if
_output
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
output_name
=
_output
.
debugName
()
if
output_name
in
input_to_node
and
input_to_node
[
output_name
]
in
nodes
:
successor_node
=
input_to_node
[
output_name
]
if
successor_node
not
in
visited
:
node_group
.
append
(
successor_node
)
node_queue
.
put
(
successor_node
)
visited
.
add
(
successor_node
)
if
output_name
in
input_to_node
:
for
successor_node
in
input_to_node
[
output_name
]:
if
successor_node
in
nodes
:
if
successor_node
not
in
visited
:
node_group
.
append
(
successor_node
)
node_queue
.
put
(
successor_node
)
visited
.
add
(
successor_node
)
else
:
outputs
.
add
(
output_name
)
else
:
outputs
.
a
ppen
d
(
output_name
)
outputs
.
a
d
d
(
output_name
)
nodepy
=
NodePyGroup
(
node_name
,
unique_name
,
module_type
,
op_type
,
node_group
,
inputs
=
inputs
,
outputs
=
outputs
)
node_group
,
inputs
=
list
(
inputs
)
,
outputs
=
list
(
outputs
)
)
return
nodepy
def
_extract_cat_info
(
self
,
node_group
,
cpp_node
):
...
...
@@ -544,7 +565,7 @@ class TorchModuleGraph(TorchGraph):
input_to_node
[
_input
].
append
(
node
)
for
output
in
node
.
outputs
:
assert
not
output
in
output_to_node
,
\
"One output cannot be generated by multiple nodes
"
"One output cannot be generated by multiple nodes
%s"
%
output
output_to_node
[
output
]
=
node
return
name_to_node
,
input_to_node
,
output_to_node
...
...
@@ -642,12 +663,22 @@ class TorchModuleGraph(TorchGraph):
omit_useless_nodes
=
True
graph
=
self
.
trace
.
graph
_logger
.
debug
(
graph
)
# build output mapping, from output debugName to its node
output_to_node
=
{
x
.
debugName
():
n
for
n
in
graph
.
nodes
()
for
x
in
n
.
outputs
()}
# build input mapping, from input debugName to its node
input_to_node
=
{
x
.
debugName
():
n
for
n
in
graph
.
nodes
()
for
x
in
n
.
inputs
()}
# build input/output mapping, from input/output debugName to its node
input_to_node
=
defaultdict
(
list
)
output_to_node
=
defaultdict
(
list
)
for
node
in
graph
.
nodes
():
if
node
.
kind
()
==
CONSTANT_KIND
:
continue
for
x
in
node
.
outputs
():
if
x
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
output_to_node
[
x
.
debugName
()].
append
(
node
)
assert
len
(
output_to_node
[
x
.
debugName
()])
<=
1
,
"One output cannot be generated by multiple nodes %s"
%
x
.
debugName
()
for
x
in
node
.
inputs
():
if
x
.
node
().
kind
()
==
CONSTANT_KIND
:
continue
input_to_node
[
x
.
debugName
()].
append
(
node
)
# build module mapping, from module name to all nodes (as list) under this module scope
module_to_nodes
=
defaultdict
(
list
)
# the mapping of function (non-module in forward) to nodes, key is scope name
...
...
@@ -668,6 +699,8 @@ class TorchModuleGraph(TorchGraph):
# associate module name with their trace graph nodes
for
node
in
graph
.
nodes
():
if
node
.
kind
()
==
CONSTANT_KIND
:
continue
module_name
=
self
.
_get_module_name
(
node
.
scopeName
())
if
module_name
in
self
.
leaf_modules
:
module_to_nodes
[
module_name
].
append
(
node
)
...
...
nni/compression/pytorch/utils/counter.py
View file @
b40e3db7
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
functools
from
collections
import
Counter
from
prettytable
import
PrettyTable
import
torch
import
torch.nn
as
nn
from
nni.compression.pytorch.compressor
import
PrunerModuleWrapper
try
:
from
thop
import
profile
except
Exception
as
e
:
print
(
'thop is not found, please install the python package: thop'
)
raise
__all__
=
[
'count_flops_params'
]
def
count_flops_params
(
model
:
nn
.
Module
,
input_size
,
custom_ops
=
None
,
verbose
=
True
):
"""
Count FLOPs and Params of the given model.
This function would identify the mask on the module
and take the pruned shape into consideration.
Note that, for sturctured pruning, we only identify
the remained filters according to its mask, which
not taking the pruned input channels into consideration,
so the calculated FLOPs will be larger than real number.
Parameters
---------
model : nn.Module
target model.
input_size: list, tuple
the input shape of data
custom_ops: dict
a mapping of (module: custom operation)
the custom operation will overwrite the default operation.
for reference, please see ``custom_mask_ops``.
def
_get_params
(
m
):
return
sum
([
p
.
numel
()
for
p
in
m
.
parameters
()])
Returns
-------
flops: float
total flops of the model
params:
total params of the model
"""
assert
input_size
is
not
None
class
ModelProfiler
:
device
=
next
(
model
.
parameters
()).
device
inputs
=
torch
.
randn
(
input_size
).
to
(
device
)
def
__init__
(
self
,
custom_ops
=
None
,
mode
=
'default'
):
"""
ModelProfiler is used to share state to hooks.
hook_module_list
=
[]
if
custom_ops
is
None
:
custom_ops
=
{}
custom_mask_ops
.
update
(
custom_ops
)
prev_m
=
None
for
m
in
model
.
modules
():
weight_mask
=
None
m_type
=
type
(
m
)
if
m_type
in
custom_mask_ops
:
if
isinstance
(
prev_m
,
PrunerModuleWrapper
):
weight_mask
=
prev_m
.
weight_mask
m
.
register_buffer
(
'weight_mask'
,
weight_mask
)
hook_module_list
.
append
(
m
)
prev_m
=
m
Parameters
----------
custom_ops: dict
a mapping of (module -> torch.nn.Module : custom operation)
the custom operation is a callback funtion to calculate
the module flops, parameters and the weight shape, it will overwrite the default operation.
for reference, please see ``self.ops``.
mode:
the mode of how to collect information. If the mode is set to `default`,
only the information of convolution and linear will be collected.
If the mode is set to `full`, other operations will also be collected.
"""
self
.
ops
=
{
nn
.
Conv1d
:
self
.
_count_convNd
,
nn
.
Conv2d
:
self
.
_count_convNd
,
nn
.
Conv3d
:
self
.
_count_convNd
,
nn
.
Linear
:
self
.
_count_linear
}
self
.
_count_bias
=
False
if
mode
==
'full'
:
self
.
ops
.
update
({
nn
.
ConvTranspose1d
:
self
.
_count_convNd
,
nn
.
ConvTranspose2d
:
self
.
_count_convNd
,
nn
.
ConvTranspose3d
:
self
.
_count_convNd
,
nn
.
BatchNorm1d
:
self
.
_count_bn
,
nn
.
BatchNorm2d
:
self
.
_count_bn
,
nn
.
BatchNorm3d
:
self
.
_count_bn
,
nn
.
LeakyReLU
:
self
.
_count_relu
,
nn
.
AvgPool1d
:
self
.
_count_avgpool
,
nn
.
AvgPool2d
:
self
.
_count_avgpool
,
nn
.
AvgPool3d
:
self
.
_count_avgpool
,
nn
.
AdaptiveAvgPool1d
:
self
.
_count_adap_avgpool
,
nn
.
AdaptiveAvgPool2d
:
self
.
_count_adap_avgpool
,
nn
.
AdaptiveAvgPool3d
:
self
.
_count_adap_avgpool
,
nn
.
Upsample
:
self
.
_count_upsample
,
nn
.
UpsamplingBilinear2d
:
self
.
_count_upsample
,
nn
.
UpsamplingNearest2d
:
self
.
_count_upsample
})
self
.
_count_bias
=
True
flops
,
params
=
profile
(
model
,
inputs
=
(
inputs
,
),
custom_ops
=
custom_mask_ops
,
verbose
=
verbose
)
if
custom_ops
is
not
None
:
self
.
ops
.
update
(
custom_ops
)
self
.
mode
=
mode
self
.
results
=
[]
for
m
in
hook_module_list
:
m
.
_buffers
.
pop
(
"weight_mask"
)
# Remove registerd buffer on the model, and fixed following issue:
# https://github.com/Lyken17/pytorch-OpCounter/issues/96
for
m
in
model
.
modules
():
if
'total_ops'
in
m
.
_buffers
:
m
.
_buffers
.
pop
(
"total_ops"
)
if
'total_params'
in
m
.
_buffers
:
m
.
_buffers
.
pop
(
"total_params"
)
def
_push_result
(
self
,
result
):
self
.
results
.
append
(
result
)
return
flops
,
params
def _get_result(self, m, flops):
    """Build the per-module result dict (flops, params, weight_shape).

    NOTE: assumes the learnable tensor is named ``weight``; a custom op
    callback supplied by the user should return a dict with the same keys.
    """
    shape = tuple(m.weight.size()) if hasattr(m, 'weight') else 0
    return {
        'flops': flops,
        'params': _get_params(m),
        'weight_shape': shape,
    }
def
_count_convNd
(
self
,
m
,
x
,
y
):
cin
=
m
.
in_channels
kernel_ops
=
m
.
weight
.
size
()[
2
]
*
m
.
weight
.
size
()[
3
]
output_size
=
torch
.
zeros
(
y
.
size
()[
2
:]).
numel
()
cout
=
y
.
size
()[
1
]
if
hasattr
(
m
,
'weight_mask'
):
cout
=
m
.
weight_mask
.
sum
()
//
(
cin
*
kernel_ops
)
total_ops
=
cout
*
output_size
*
kernel_ops
*
cin
//
m
.
groups
# cout x oW x oH
if
self
.
_count_bias
:
bias_flops
=
1
if
m
.
bias
is
not
None
else
0
total_ops
+=
cout
*
output_size
*
bias_flops
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_linear
(
self
,
m
,
x
,
y
):
out_features
=
m
.
out_features
if
hasattr
(
m
,
'weight_mask'
):
out_features
=
m
.
weight_mask
.
sum
()
//
m
.
in_features
total_ops
=
out_features
*
m
.
in_features
if
self
.
_count_bias
:
bias_flops
=
1
if
m
.
bias
is
not
None
else
0
total_ops
+=
out_features
*
bias_flops
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_bn
(
self
,
m
,
x
,
y
):
total_ops
=
2
*
x
[
0
].
numel
()
return
self
.
_get_result
(
m
,
total_ops
)
def
_count_relu
(
self
,
m
,
x
,
y
):
total_ops
=
x
[
0
].
numel
()
return
self
.
_get_result
(
m
,
total_ops
)
bias_flops
=
1
if
m
.
bias
is
not
None
else
0
def
_count_avgpool
(
self
,
m
,
x
,
y
):
total_ops
=
y
.
numel
()
return
self
.
_get_result
(
m
,
total_ops
)
if
m
.
weight_mask
is
not
None
:
output_channel
=
m
.
weight_mask
.
sum
()
//
(
m
.
in_channels
*
kernel_size
)
def
_count_adap_avgpool
(
self
,
m
,
x
,
y
):
kernel
=
torch
.
Tensor
([
*
(
x
[
0
].
shape
[
2
:])])
//
torch
.
Tensor
(
list
((
m
.
output_size
,))).
squeeze
()
total_add
=
int
(
torch
.
prod
(
kernel
))
total_div
=
1
kernel_ops
=
total_add
+
total_div
num_elements
=
y
.
numel
()
total_ops
=
kernel_ops
*
num_elements
total_ops
=
output_channel
*
output_size
*
(
m
.
in_channels
//
m
.
groups
*
kernel_size
+
bias_fl
ops
)
return
self
.
_get_result
(
m
,
total_
ops
)
m
.
total_ops
+=
torch
.
DoubleTensor
([
int
(
total_ops
)])
def
_count_upsample
(
self
,
m
,
x
,
y
):
if
m
.
mode
==
'linear'
:
total_ops
=
y
.
nelement
()
*
5
# 2 muls + 3 add
elif
m
.
mode
==
'bilinear'
:
# https://en.wikipedia.org/wiki/Bilinear_interpolation
total_ops
=
y
.
nelement
()
*
11
# 6 muls + 5 adds
elif
m
.
mode
==
'bicubic'
:
# https://en.wikipedia.org/wiki/Bicubic_interpolation
# Product matrix [4x4] x [4x4] x [4x4]
ops_solve_A
=
224
# 128 muls + 96 adds
ops_solve_p
=
35
# 16 muls + 12 adds + 4 muls + 3 adds
total_ops
=
y
.
nelement
()
*
(
ops_solve_A
+
ops_solve_p
)
elif
m
.
mode
==
'trilinear'
:
# https://en.wikipedia.org/wiki/Trilinear_interpolation
# can viewed as 2 bilinear + 1 linear
total_ops
=
y
.
nelement
()
*
(
13
*
2
+
5
)
else
:
total_ops
=
0
return
self
.
_get_result
(
m
,
total_ops
)
def
count_linear_mask
(
m
,
x
,
y
):
def count_module(self, m, x, y, name):
    """Forward hook: dispatch to the per-type counter and record its stats.

    ``x`` is assumed to be a tuple holding a single input tensor.
    """
    stats = self.ops[type(m)](m, x, y)
    entry = {
        'name': name,
        'input_size': tuple(x[0].size()),
        'output_size': tuple(y.size()),
        'module_type': type(m).__name__,
    }
    entry.update(stats)  # counter-provided keys win, as with {**base, **stats}
    self._push_result(entry)
def sum_flops(self):
    """Total FLOPs over all recorded entries (a reused module counts each call)."""
    return sum(entry['flops'] for entry in self.results)
def sum_params(self):
    """Total parameters, counting each named module only once even if it was
    called multiple times during the forward pass."""
    unique = {}
    for entry in self.results:
        unique[entry['name']] = entry['params']
    return sum(unique.values())
def format_results(self):
    """Render the collected per-module results as a PrettyTable.

    If any module name appears more than once in ``self.results`` (the module
    was invoked multiple times during the profiled forward pass), an extra
    '#Call' column is added showing the 1-based call index per name.
    """
    table = PrettyTable()
    # first pass: detect whether any module was used more than once
    name_counter = Counter([s['name'] for s in self.results])
    has_multi_use = any(map(lambda v: v > 1, name_counter.values()))
    name_counter = Counter()  # clear the counter to count from 0
    headers = [
        'Index',
        'Name',
        'Type',
        'Weight Shape',
        'FLOPs',
        '#Params',
    ]
    if has_multi_use:
        headers.append('#Call')
    table.field_names = headers
    for i, result in enumerate(self.results):
        row_values = [
            i,
            result['name'],
            result['module_type'],
            str(result['weight_shape']),
            result['flops'],
            result['params'],
        ]
        name_counter[result['name']] += 1
        if has_multi_use:
            row_values.append(name_counter[result['name']])
        table.add_row(row_values)
    return table
def count_flops_params(model, x, custom_ops=None, verbose=True, mode='default'):
    """
    Count FLOPs and Params of the given model. This function would
    identify the mask on the module and take the pruned shape into consideration.
    Note that, for sturctured pruning, we only identify the remained filters
    according to its mask, and do not take the pruned input channels into consideration,
    so the calculated FLOPs will be larger than real number.

    Parameters
    ---------
    model : nn.Module
        Target model.
    x : tuple or tensor
        The input shape of data (a tuple), a tensor or a tuple of tensor as input data.
    custom_ops : dict
        A mapping of (module -> torch.nn.Module : custom operation)
        the custom operation is a callback funtion to calculate
        the module flops and parameters, it will overwrite the default operation.
        for reference, please see ``ops`` in ``ModelProfiler``.
    verbose : bool
        If False, mute detail information about modules. Default is True.
    mode : str
        the mode of how to collect information. If the mode is set to ``default``,
        only the information of convolution and linear will be collected.
        If the mode is set to ``full``, other operations will also be collected.

    Returns
    -------
    tuple of int, int and dict
        Representing total FLOPs, total parameters, and a detailed list of results respectively.
        The list of results are a list of dict, each of which contains (name, module_type, weight_shape,
        flops, params, input_size, output_size) as its keys.
    """
    assert isinstance(x, (tuple, torch.Tensor))
    assert mode in ['default', 'full']

    original_device = next(model.parameters()).device
    training = model.training

    if isinstance(x, tuple) and all(isinstance(t, int) for t in x):
        # a shape tuple: fabricate a zero tensor of that shape
        x = (torch.zeros(x).to(original_device), )
    elif torch.is_tensor(x):
        x = (x.to(original_device), )
    else:
        # materialize as a tuple (was a one-shot generator expression) so the
        # inputs can safely be unpacked or inspected more than once
        x = tuple(t.to(original_device) for t in x)

    handler_collection = []
    profiler = ModelProfiler(custom_ops, mode)

    prev_m = None
    for name, m in model.named_modules():
        # dealing with weight mask here
        if isinstance(prev_m, PrunerModuleWrapper):
            # weight mask is set to weight mask of its parent (wrapper)
            weight_mask = prev_m.weight_mask
            m.weight_mask = weight_mask
        prev_m = m

        if type(m) in profiler.ops:
            # a leaf node with a registered counter: hook it
            _handler = m.register_forward_hook(functools.partial(profiler.count_module, name=name))
            handler_collection.append(_handler)

    model.eval()

    with torch.no_grad():
        model(*x)

    # restore origin status: drop the temporary masks and the hooks
    for name, m in model.named_modules():
        if hasattr(m, 'weight_mask'):
            delattr(m, 'weight_mask')
    model.train(training).to(original_device)
    for handler in handler_collection:
        handler.remove()

    if verbose:
        # get detail information
        print(profiler.format_results())
        print(f'FLOPs total: {profiler.sum_flops()}')
        print(f'#Params total: {profiler.sum_params()}')

    return profiler.sum_flops(), profiler.sum_params(), profiler.results
\ No newline at end of file
nni/compression/pytorch/utils/mask_conflict.py
View file @
b40e3db7
...
...
@@ -36,9 +36,11 @@ def fix_mask_conflict(masks, model=None, dummy_input=None, traced=None):
# this traced model.
if
traced
is
None
:
assert
model
is
not
None
and
dummy_input
is
not
None
with
torch
.
onnx
.
set_training
(
model
,
False
):
# We need to trace the model in this way, else it will have problems
traced
=
torch
.
jit
.
trace
(
model
,
dummy_input
)
training
=
model
.
training
model
.
eval
()
# We need to trace the model in eval mode
traced
=
torch
.
jit
.
trace
(
model
,
dummy_input
)
model
.
train
(
training
)
fix_group_mask
=
GroupMaskConflict
(
masks
,
model
,
dummy_input
,
traced
)
masks
=
fix_group_mask
.
fix_mask
()
...
...
nni/experiment/nni_client.py
View file @
b40e3db7
...
...
@@ -100,10 +100,7 @@ class TrialResult:
self
.
value
=
None
self
.
trialJobId
=
None
for
key
in
json_obj
.
keys
():
if
key
==
'id'
:
setattr
(
self
,
'trialJobId'
,
json_obj
[
key
])
elif
hasattr
(
self
,
key
):
setattr
(
self
,
key
,
json_obj
[
key
])
setattr
(
self
,
key
,
json_obj
[
key
])
self
.
value
=
json
.
loads
(
self
.
value
)
def
__repr__
(
self
):
...
...
@@ -220,10 +217,7 @@ class TrialJob:
self
.
finalMetricData
=
None
self
.
stderrPath
=
None
for
key
in
json_obj
.
keys
():
if
key
==
'id'
:
setattr
(
self
,
'trialJobId'
,
json_obj
[
key
])
elif
hasattr
(
self
,
key
):
setattr
(
self
,
key
,
json_obj
[
key
])
setattr
(
self
,
key
,
json_obj
[
key
])
if
self
.
hyperParameters
:
self
.
hyperParameters
=
[
TrialHyperParameters
(
json
.
loads
(
e
))
for
e
in
self
.
hyperParameters
]
if
self
.
finalMetricData
:
...
...
nni/runtime/msg_dispatcher.py
View file @
b40e3db7
...
...
@@ -39,7 +39,7 @@ def _sort_history(history):
# Tuner global variables
_next_parameter_id
=
0
_trial_params
=
{}
'''key:
trial job
ID; value: parameters'''
'''key:
parameter
ID; value: parameters'''
_customized_parameter_ids
=
set
()
...
...
@@ -114,7 +114,7 @@ class MsgDispatcher(MsgDispatcherBase):
data: a list of dictionaries, each of which has at least two keys, 'parameter' and 'value'
"""
for
entry
in
data
:
entry
[
'value'
]
=
entry
[
'value'
]
if
type
(
entry
[
'value'
])
is
str
else
json_tricks
.
dumps
(
entry
[
'value'
])
entry
[
'value'
]
=
entry
[
'value'
]
if
type
(
entry
[
'value'
])
is
str
else
json_tricks
.
dumps
(
entry
[
'value'
])
entry
[
'value'
]
=
json_tricks
.
loads
(
entry
[
'value'
])
self
.
tuner
.
import_data
(
data
)
...
...
@@ -182,8 +182,11 @@ class MsgDispatcher(MsgDispatcherBase):
customized
=
True
else
:
customized
=
False
if
id_
in
_trial_params
:
self
.
tuner
.
receive_trial_result
(
id_
,
_trial_params
[
id_
],
value
,
customized
=
customized
,
trial_job_id
=
data
.
get
(
'trial_job_id'
))
else
:
_logger
.
warning
(
'Find unknown job parameter id %s, maybe something goes wrong.'
,
_trial_params
[
id_
])
def
_handle_intermediate_metric_data
(
self
,
data
):
"""Call assessor to process intermediate results
...
...
nni/runtime/platform/__init__.py
View file @
b40e3db7
...
...
@@ -9,7 +9,7 @@ if trial_env_vars.NNI_PLATFORM is None:
from
.standalone
import
*
elif
trial_env_vars
.
NNI_PLATFORM
==
'unittest'
:
from
.test
import
*
elif
trial_env_vars
.
NNI_PLATFORM
in
(
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
,
'aml'
):
elif
trial_env_vars
.
NNI_PLATFORM
in
(
'adl'
,
'local'
,
'remote'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
,
'aml'
):
from
.local
import
*
else
:
raise
RuntimeError
(
'Unknown platform %s'
%
trial_env_vars
.
NNI_PLATFORM
)
nni/tools/nnictl/common_utils.py
View file @
b40e3db7
...
...
@@ -5,11 +5,14 @@ import os
import
sys
import
json
import
tempfile
import
time
import
socket
import
string
import
random
import
ruamel.yaml
as
yaml
import
psutil
import
filelock
import
glob
from
colorama
import
Fore
from
.constants
import
ERROR_INFO
,
NORMAL_INFO
,
WARNING_INFO
...
...
@@ -95,3 +98,36 @@ def generate_temp_dir():
temp_dir
=
generate_folder_name
()
os
.
makedirs
(
temp_dir
)
return
temp_dir
class SimplePreemptiveLock(filelock.SoftFileLock):
    '''this is a lock support check lock expiration, if you do not need check expiration, you can use SoftFileLock'''
    def __init__(self, lock_file, stale=-1):
        # timeout=-1: acquire() blocks indefinitely (no acquisition timeout)
        super(__class__, self).__init__(lock_file, timeout=-1)
        # per-process lock file name: "<lock_file>.<pid>"
        self._lock_file_name = '{}.{}'.format(self._lock_file, os.getpid())
        # seconds after which another process's lock file is treated as
        # expired; a negative value disables expiration entirely
        self._stale = stale

    def _acquire(self):
        open_mode = os.O_WRONLY | os.O_CREAT | os.O_EXCL | os.O_TRUNC
        try:
            # every holder creates "<lock_file>.<pid>"; scan all of them
            lock_file_names = glob.glob(self._lock_file + '.*')
            for file_name in lock_file_names:
                # another lock file exists and is still fresh (or expiry is
                # disabled) -> do not acquire; caller will retry
                if os.path.exists(file_name) and (self._stale < 0 or time.time() - os.stat(file_name).st_mtime < self._stale):
                    return None
            fd = os.open(self._lock_file_name, open_mode)
        except (IOError, OSError):
            # creation race lost or FS error: treated as "not acquired"
            pass
        else:
            self._lock_file_fd = fd
        return None

    def _release(self):
        os.close(self._lock_file_fd)
        self._lock_file_fd = None
        try:
            os.remove(self._lock_file_name)
        except OSError:
            # lock file may already be gone (e.g. preempted as stale)
            pass
        return None
def get_file_lock(path: str, stale=-1):
    """Create a preemptive file lock guarding *path* (lock file "<path>.lock").

    Parameters
    ----------
    path : str
        The file to guard.
    stale : int
        Seconds after which an existing holder's lock is considered stale and
        may be preempted; a negative value disables expiration.
    """
    # Forward the caller's `stale` value. Previously this was hard-coded to
    # `stale=-1`, silently disabling lock expiration for every caller (e.g.
    # config_utils passes stale=2 expecting 2-second expiry).
    return SimplePreemptiveLock(path + '.lock', stale=stale)
nni/tools/nnictl/config_schema.py
View file @
b40e3db7
...
...
@@ -124,7 +124,7 @@ common_schema = {
Optional
(
'maxExecDuration'
):
And
(
Regex
(
r
'^[1-9][0-9]*[s|m|h|d]$'
,
error
=
'ERROR: maxExecDuration format is [digit]{s,m,h,d}'
)),
Optional
(
'maxTrialNum'
):
setNumberRange
(
'maxTrialNum'
,
int
,
1
,
99999
),
'trainingServicePlatform'
:
setChoice
(
'trainingServicePlatform'
,
'remote'
,
'local'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
,
'aml'
),
'trainingServicePlatform'
,
'adl'
,
'remote'
,
'local'
,
'pai'
,
'kubeflow'
,
'frameworkcontroller'
,
'paiYarn'
,
'dlts'
,
'aml'
),
Optional
(
'searchSpacePath'
):
And
(
os
.
path
.
exists
,
error
=
SCHEMA_PATH_ERROR
%
'searchSpacePath'
),
Optional
(
'multiPhase'
):
setType
(
'multiPhase'
,
bool
),
Optional
(
'multiThread'
):
setType
(
'multiThread'
,
bool
),
...
...
@@ -262,6 +262,30 @@ aml_config_schema = {
}
}
adl_trial_schema
=
{
'trial'
:{
'codeDir'
:
setType
(
'codeDir'
,
str
),
'command'
:
setType
(
'command'
,
str
),
'gpuNum'
:
setNumberRange
(
'gpuNum'
,
int
,
0
,
99999
),
'image'
:
setType
(
'image'
,
str
),
Optional
(
'imagePullSecrets'
):
[{
'name'
:
setType
(
'name'
,
str
)
}],
Optional
(
'nfs'
):
{
'server'
:
setType
(
'server'
,
str
),
'path'
:
setType
(
'path'
,
str
),
'containerMountPath'
:
setType
(
'containerMountPath'
,
str
)
},
Optional
(
'adaptive'
):
setType
(
'adaptive'
,
bool
),
Optional
(
'checkpoint'
):
{
'storageClass'
:
setType
(
'storageClass'
,
str
),
'storageSize'
:
setType
(
'storageSize'
,
str
)
},
Optional
(
'cpuNum'
):
setNumberRange
(
'cpuNum'
,
int
,
0
,
99999
),
Optional
(
'memorySize'
):
setType
(
'memorySize'
,
str
)
}
}
kubeflow_trial_schema
=
{
'trial'
:
{
'codeDir'
:
setPathCheck
(
'codeDir'
),
...
...
@@ -404,6 +428,7 @@ machine_list_schema = {
}
training_service_schema_dict
=
{
'adl'
:
Schema
({
**
common_schema
,
**
adl_trial_schema
}),
'local'
:
Schema
({
**
common_schema
,
**
common_trial_schema
}),
'remote'
:
Schema
({
**
common_schema
,
**
common_trial_schema
,
**
machine_list_schema
,
**
remote_config_schema
}),
'pai'
:
Schema
({
**
common_schema
,
**
pai_trial_schema
,
**
pai_config_schema
}),
...
...
nni/tools/nnictl/config_utils.py
View file @
b40e3db7
...
...
@@ -4,8 +4,10 @@
import
os
import
json
import
shutil
import
time
from
.constants
import
NNICTL_HOME_DIR
from
.command_utils
import
print_error
from
.common_utils
import
get_file_lock
class
Config
:
'''a util class to load and save config'''
...
...
@@ -34,7 +36,7 @@ class Config:
if
self
.
config
:
try
:
with
open
(
self
.
config_file
,
'w'
)
as
file
:
json
.
dump
(
self
.
config
,
file
)
json
.
dump
(
self
.
config
,
file
,
indent
=
4
)
except
IOError
as
error
:
print
(
'Error:'
,
error
)
return
...
...
@@ -54,39 +56,53 @@ class Experiments:
def __init__(self, home_dir=NNICTL_HOME_DIR):
    """Load the shared experiment registry under *home_dir*, guarded by a file lock."""
    os.makedirs(home_dir, exist_ok=True)
    self.experiment_file = os.path.join(home_dir, '.experiment')
    # stale=2: a lock file older than 2 seconds (e.g. left by a dead process)
    # may be preempted
    self.lock = get_file_lock(self.experiment_file, stale=2)
    with self.lock:
        self.experiments = self.read_file()
def add_experiment(self, expId, port, startTime, platform, experiment_name,
                   endTime='N/A', status='INITIALIZED', tag=None, pid=None,
                   webuiUrl=None, logDir=None):
    '''set {key:value} pairs to self.experiment'''
    # Avoid shared mutable default arguments: the previous `tag=[]`,
    # `webuiUrl=[]`, `logDir=[]` defaults were single list objects shared by
    # every call; mutating a stored value would pollute later experiments.
    # None (default or explicit) is normalized to a fresh empty list.
    tag = [] if tag is None else tag
    webuiUrl = [] if webuiUrl is None else webuiUrl
    logDir = [] if logDir is None else logDir
    with self.lock:
        # re-read under the lock so concurrent nnictl processes don't clobber
        # each other's entries
        self.experiments = self.read_file()
        self.experiments[expId] = {}
        self.experiments[expId]['id'] = expId
        self.experiments[expId]['port'] = port
        self.experiments[expId]['startTime'] = startTime
        self.experiments[expId]['endTime'] = endTime
        self.experiments[expId]['status'] = status
        self.experiments[expId]['platform'] = platform
        self.experiments[expId]['experimentName'] = experiment_name
        self.experiments[expId]['tag'] = tag
        self.experiments[expId]['pid'] = pid
        self.experiments[expId]['webuiUrl'] = webuiUrl
        self.experiments[expId]['logDir'] = logDir
        self.write_file()
def update_experiment(self, expId, key, value):
    '''Update experiment'''
    with self.lock:
        # refresh from disk under the lock before mutating
        current = self.read_file()
        self.experiments = current
        if expId in current:
            current[expId][key] = value
            self.write_file()
            return True
        return False
def remove_experiment(self, expId):
    '''remove an experiment by id'''
    with self.lock:
        # refresh from disk under the lock before mutating
        registry = self.read_file()
        self.experiments = registry
        if expId in registry:
            registry.pop(expId)
            fileName = expId
            if fileName:
                # per-experiment log directory lives under NNICTL_HOME_DIR/<id>
                logPath = os.path.join(NNICTL_HOME_DIR, fileName)
                try:
                    shutil.rmtree(logPath)
                except FileNotFoundError:
                    print_error('{0} does not exist.'.format(logPath))
        self.write_file()
def
get_all_experiments
(
self
):
'''return all of experiments'''
...
...
@@ -96,7 +112,7 @@ class Experiments:
'''save config to local file'''
try
:
with
open
(
self
.
experiment_file
,
'w'
)
as
file
:
json
.
dump
(
self
.
experiments
,
file
)
json
.
dump
(
self
.
experiments
,
file
,
indent
=
4
)
except
IOError
as
error
:
print
(
'Error:'
,
error
)
return
''
...
...
nni/tools/nnictl/constants.py
View file @
b40e3db7
...
...
@@ -4,7 +4,7 @@
import
os
from
colorama
import
Fore
NNICTL_HOME_DIR
=
os
.
path
.
join
(
os
.
path
.
expanduser
(
'~'
),
'
.local'
,
'nnictl
'
)
NNICTL_HOME_DIR
=
os
.
path
.
join
(
os
.
path
.
expanduser
(
'~'
),
'
nni-experiments
'
)
NNI_HOME_DIR
=
os
.
path
.
join
(
os
.
path
.
expanduser
(
'~'
),
'nni-experiments'
)
...
...
@@ -64,21 +64,21 @@ TRIAL_MONITOR_TAIL = '----------------------------------------------------------
INSTALLABLE_PACKAGE_META
=
{
'SMAC'
:
{
'type'
:
'tuner'
,
'class_name'
:
'nni.smac_tuner.smac_tuner.SMACTuner'
,
'class_name'
:
'nni.
algorithms.hpo.
smac_tuner.smac_tuner.SMACTuner'
,
'code_sub_dir'
:
'smac_tuner'
,
'class_args_validator'
:
'nni.smac_tuner.smac_tuner.SMACClassArgsValidator'
'class_args_validator'
:
'nni.
algorithms.hpo.
smac_tuner.smac_tuner.SMACClassArgsValidator'
},
'BOHB'
:
{
'type'
:
'advisor'
,
'class_name'
:
'nni.bohb_advisor.bohb_advisor.BOHB'
,
'class_name'
:
'nni.
algorithms.hpo.
bohb_advisor.bohb_advisor.BOHB'
,
'code_sub_dir'
:
'bohb_advisor'
,
'class_args_validator'
:
'nni.bohb_advisor.bohb_advisor.BOHBClassArgsValidator'
'class_args_validator'
:
'nni.
algorithms.hpo.
bohb_advisor.bohb_advisor.BOHBClassArgsValidator'
},
'PPOTuner'
:
{
'type'
:
'tuner'
,
'class_name'
:
'nni.ppo_tuner.ppo_tuner.PPOTuner'
,
'class_name'
:
'nni.
algorithms.hpo.
ppo_tuner.ppo_tuner.PPOTuner'
,
'code_sub_dir'
:
'ppo_tuner'
,
'class_args_validator'
:
'nni.ppo_tuner.ppo_tuner.PPOClassArgsValidator'
'class_args_validator'
:
'nni.
algorithms.hpo.
ppo_tuner.ppo_tuner.PPOClassArgsValidator'
}
}
...
...
nni/tools/nnictl/launcher.py
View file @
b40e3db7
...
...
@@ -23,10 +23,11 @@ from .constants import NNICTL_HOME_DIR, ERROR_INFO, REST_TIME_OUT, EXPERIMENT_SU
from
.command_utils
import
check_output_command
,
kill_command
from
.nnictl_utils
import
update_experiment
def get_log_path(experiment_id):
    '''generate stdout and stderr log path'''
    log_dir = os.path.join(NNICTL_HOME_DIR, experiment_id, 'log')
    os.makedirs(log_dir, exist_ok=True)
    stdout_full_path = os.path.join(log_dir, 'nnictl_stdout.log')
    stderr_full_path = os.path.join(log_dir, 'nnictl_stderr.log')
    return stdout_full_path, stderr_full_path
def
print_log_content
(
config_file_name
):
...
...
@@ -38,7 +39,7 @@ def print_log_content(config_file_name):
print_normal
(
' Stderr:'
)
print
(
check_output_command
(
stderr_full_path
))
def
start_rest_server
(
port
,
platform
,
mode
,
config_file_name
,
foreground
=
False
,
experiment_id
=
None
,
log_dir
=
None
,
log_level
=
None
):
def
start_rest_server
(
port
,
platform
,
mode
,
experiment_id
,
foreground
=
False
,
log_dir
=
None
,
log_level
=
None
):
'''Run nni manager process'''
if
detect_port
(
port
):
print_error
(
'Port %s is used by another process, please reset the port!
\n
'
\
...
...
@@ -63,7 +64,8 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
node_command
=
os
.
path
.
join
(
entry_dir
,
'node.exe'
)
else
:
node_command
=
os
.
path
.
join
(
entry_dir
,
'node'
)
cmds
=
[
node_command
,
'--max-old-space-size=4096'
,
entry_file
,
'--port'
,
str
(
port
),
'--mode'
,
platform
]
cmds
=
[
node_command
,
'--max-old-space-size=4096'
,
entry_file
,
'--port'
,
str
(
port
),
'--mode'
,
platform
,
\
'--experiment_id'
,
experiment_id
]
if
mode
==
'view'
:
cmds
+=
[
'--start_mode'
,
'resume'
]
cmds
+=
[
'--readonly'
,
'true'
]
...
...
@@ -73,13 +75,12 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
cmds
+=
[
'--log_dir'
,
log_dir
]
if
log_level
is
not
None
:
cmds
+=
[
'--log_level'
,
log_level
]
if
mode
in
[
'resume'
,
'view'
]:
cmds
+=
[
'--experiment_id'
,
experiment_id
]
if
foreground
:
cmds
+=
[
'--foreground'
,
'true'
]
stdout_full_path
,
stderr_full_path
=
get_log_path
(
config_file_name
)
stdout_full_path
,
stderr_full_path
=
get_log_path
(
experiment_id
)
with
open
(
stdout_full_path
,
'a+'
)
as
stdout_file
,
open
(
stderr_full_path
,
'a+'
)
as
stderr_file
:
time_now
=
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
start_time
=
time
.
time
()
time_now
=
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
start_time
))
#add time information in the header of log files
log_header
=
LOG_HEADER
%
str
(
time_now
)
stdout_file
.
write
(
log_header
)
...
...
@@ -95,7 +96,7 @@ def start_rest_server(port, platform, mode, config_file_name, foreground=False,
process
=
Popen
(
cmds
,
cwd
=
entry_dir
,
stdout
=
PIPE
,
stderr
=
PIPE
)
else
:
process
=
Popen
(
cmds
,
cwd
=
entry_dir
,
stdout
=
stdout_file
,
stderr
=
stderr_file
)
return
process
,
str
(
time_now
)
return
process
,
int
(
start_time
*
1000
)
def
set_trial_config
(
experiment_config
,
port
,
config_file_name
):
'''set trial configuration'''
...
...
@@ -136,6 +137,14 @@ def set_local_config(experiment_config, port, config_file_name):
return
set_trial_config
(
experiment_config
,
port
,
config_file_name
),
None
def set_adl_config(experiment_config, port, config_file_name):
    '''set adl configuration'''
    ok, message = setNNIManagerIp(experiment_config, port, config_file_name)
    if not ok:
        return ok, message
    # set trial_config
    return set_trial_config(experiment_config, port, config_file_name), None
def
set_remote_config
(
experiment_config
,
port
,
config_file_name
):
'''Call setClusterMetadata to pass trial'''
#set machine_list
...
...
@@ -393,7 +402,9 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
'''call set_cluster_metadata for specific platform'''
print_normal
(
'Setting {0} config...'
.
format
(
platform
))
config_result
,
err_msg
=
None
,
None
if
platform
==
'local'
:
if
platform
==
'adl'
:
config_result
,
err_msg
=
set_adl_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'local'
:
config_result
,
err_msg
=
set_local_config
(
experiment_config
,
port
,
config_file_name
)
elif
platform
==
'remote'
:
config_result
,
err_msg
=
set_remote_config
(
experiment_config
,
port
,
config_file_name
)
...
...
@@ -422,9 +433,9 @@ def set_platform_config(platform, experiment_config, port, config_file_name, res
raise
Exception
(
ERROR_INFO
%
'Rest server stopped!'
)
exit
(
1
)
def
launch_experiment
(
args
,
experiment_config
,
mode
,
config_file_name
,
experiment_id
=
None
):
def
launch_experiment
(
args
,
experiment_config
,
mode
,
experiment_id
):
'''follow steps to start rest server and start experiment'''
nni_config
=
Config
(
config_file_name
)
nni_config
=
Config
(
experiment_id
)
# check packages for tuner
package_name
,
module_name
=
None
,
None
if
experiment_config
.
get
(
'tuner'
)
and
experiment_config
[
'tuner'
].
get
(
'builtinTunerName'
):
...
...
@@ -435,15 +446,15 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
module_name
,
_
=
get_builtin_module_class_name
(
'advisors'
,
package_name
)
if
package_name
and
module_name
:
try
:
stdout_full_path
,
stderr_full_path
=
get_log_path
(
config_file_name
)
stdout_full_path
,
stderr_full_path
=
get_log_path
(
experiment_id
)
with
open
(
stdout_full_path
,
'a+'
)
as
stdout_file
,
open
(
stderr_full_path
,
'a+'
)
as
stderr_file
:
check_call
([
sys
.
executable
,
'-c'
,
'import %s'
%
(
module_name
)],
stdout
=
stdout_file
,
stderr
=
stderr_file
)
except
CalledProcessError
:
print_error
(
'some errors happen when import package %s.'
%
(
package_name
))
print_log_content
(
config_file_name
)
print_log_content
(
experiment_id
)
if
package_name
in
INSTALLABLE_PACKAGE_META
:
print_error
(
'If %s is not installed, it should be installed through '
\
'
\'
nnictl package install --name %s
\'
'
%
(
package_name
,
package_name
))
'
\'
nnictl package install --name %s
\'
'
%
(
package_name
,
package_name
))
exit
(
1
)
log_dir
=
experiment_config
[
'logDir'
]
if
experiment_config
.
get
(
'logDir'
)
else
None
log_level
=
experiment_config
[
'logLevel'
]
if
experiment_config
.
get
(
'logLevel'
)
else
None
...
...
@@ -455,7 +466,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
log_level
=
'debug'
# start rest server
rest_process
,
start_time
=
start_rest_server
(
args
.
port
,
experiment_config
[
'trainingServicePlatform'
],
\
mode
,
config_file_name
,
foreground
,
experiment_id
,
log_dir
,
log_level
)
mode
,
experiment_i
d
,
foregroun
d
,
log_dir
,
log_level
)
nni_config
.
set_config
(
'restServerPid'
,
rest_process
.
pid
)
# Deal with annotation
if
experiment_config
.
get
(
'useAnnotation'
):
...
...
@@ -481,7 +492,7 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
print_normal
(
'Successfully started Restful server!'
)
else
:
print_error
(
'Restful server start failed!'
)
print_log_content
(
config_file_name
)
print_log_content
(
experiment_id
)
try
:
kill_command
(
rest_process
.
pid
)
except
Exception
:
...
...
@@ -490,21 +501,25 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
if
mode
!=
'view'
:
# set platform configuration
set_platform_config
(
experiment_config
[
'trainingServicePlatform'
],
experiment_config
,
args
.
port
,
\
config_file_name
,
rest_process
)
experiment_id
,
rest_process
)
# start a new experiment
print_normal
(
'Starting experiment...'
)
# save experiment information
nnictl_experiment_config
=
Experiments
()
nnictl_experiment_config
.
add_experiment
(
experiment_id
,
args
.
port
,
start_time
,
experiment_config
[
'trainingServicePlatform'
],
experiment_config
[
'experimentName'
],
pid
=
rest_process
.
pid
,
logDir
=
log_dir
)
# set debug configuration
if
mode
!=
'view'
and
experiment_config
.
get
(
'debug'
)
is
None
:
experiment_config
[
'debug'
]
=
args
.
debug
response
=
set_experiment
(
experiment_config
,
mode
,
args
.
port
,
config_file_name
)
response
=
set_experiment
(
experiment_config
,
mode
,
args
.
port
,
experiment_id
)
if
response
:
if
experiment_id
is
None
:
experiment_id
=
json
.
loads
(
response
.
text
).
get
(
'experiment_id'
)
nni_config
.
set_config
(
'experimentId'
,
experiment_id
)
else
:
print_error
(
'Start experiment failed!'
)
print_log_content
(
config_file_name
)
print_log_content
(
experiment_id
)
try
:
kill_command
(
rest_process
.
pid
)
except
Exception
:
...
...
@@ -516,12 +531,6 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
web_ui_url_list
=
get_local_urls
(
args
.
port
)
nni_config
.
set_config
(
'webuiUrl'
,
web_ui_url_list
)
# save experiment information
nnictl_experiment_config
=
Experiments
()
nnictl_experiment_config
.
add_experiment
(
experiment_id
,
args
.
port
,
start_time
,
config_file_name
,
experiment_config
[
'trainingServicePlatform'
],
experiment_config
[
'experimentName'
])
print_normal
(
EXPERIMENT_SUCCESS_INFO
%
(
experiment_id
,
' '
.
join
(
web_ui_url_list
)))
if
mode
!=
'view'
and
args
.
foreground
:
try
:
...
...
@@ -534,8 +543,9 @@ def launch_experiment(args, experiment_config, mode, config_file_name, experimen
def
create_experiment
(
args
):
'''start a new experiment'''
config_file_name
=
''
.
join
(
random
.
sample
(
string
.
ascii_letters
+
string
.
digits
,
8
))
nni_config
=
Config
(
config_file_name
)
experiment_id
=
''
.
join
(
random
.
sample
(
string
.
ascii_letters
+
string
.
digits
,
8
))
nni_config
=
Config
(
experiment_id
)
nni_config
.
set_config
(
'experimentId'
,
experiment_id
)
config_path
=
os
.
path
.
abspath
(
args
.
config
)
if
not
os
.
path
.
exists
(
config_path
):
print_error
(
'Please set correct config path!'
)
...
...
@@ -550,9 +560,9 @@ def create_experiment(args):
nni_config
.
set_config
(
'experimentConfig'
,
experiment_config
)
nni_config
.
set_config
(
'restServerPort'
,
args
.
port
)
try
:
launch_experiment
(
args
,
experiment_config
,
'new'
,
config_file_name
)
launch_experiment
(
args
,
experiment_config
,
'new'
,
experiment_id
)
except
Exception
as
exception
:
nni_config
=
Config
(
config_file_name
)
nni_config
=
Config
(
experiment_id
)
restServerPid
=
nni_config
.
get_config
(
'restServerPid'
)
if
restServerPid
:
kill_command
(
restServerPid
)
...
...
@@ -579,17 +589,13 @@ def manage_stopped_experiment(args, mode):
exit
(
1
)
experiment_id
=
args
.
id
print_normal
(
'{0} experiment {1}...'
.
format
(
mode
,
experiment_id
))
nni_config
=
Config
(
experiment_d
ict
[
experiment_id
][
'fileName'
]
)
nni_config
=
Config
(
experiment_
i
d
)
experiment_config
=
nni_config
.
get_config
(
'experimentConfig'
)
experiment_id
=
nni_config
.
get_config
(
'experimentId'
)
new_config_file_name
=
''
.
join
(
random
.
sample
(
string
.
ascii_letters
+
string
.
digits
,
8
))
new_nni_config
=
Config
(
new_config_file_name
)
new_nni_config
.
set_config
(
'experimentConfig'
,
experiment_config
)
new_nni_config
.
set_config
(
'restServerPort'
,
args
.
port
)
nni_config
.
set_config
(
'restServerPort'
,
args
.
port
)
try
:
launch_experiment
(
args
,
experiment_config
,
mode
,
new_config_file_name
,
experiment_id
)
launch_experiment
(
args
,
experiment_config
,
mode
,
experiment_id
)
except
Exception
as
exception
:
nni_config
=
Config
(
new_config_file_name
)
nni_config
=
Config
(
experiment_id
)
restServerPid
=
nni_config
.
get_config
(
'restServerPid'
)
if
restServerPid
:
kill_command
(
restServerPid
)
...
...
nni/tools/nnictl/launcher_utils.py
View file @
b40e3db7
...
...
@@ -32,6 +32,8 @@ def parse_time(time):
def
parse_path
(
experiment_config
,
config_path
):
'''Parse path in config file'''
expand_path
(
experiment_config
,
'searchSpacePath'
)
if
experiment_config
.
get
(
'logDir'
):
expand_path
(
experiment_config
,
'logDir'
)
if
experiment_config
.
get
(
'trial'
):
expand_path
(
experiment_config
[
'trial'
],
'codeDir'
)
if
experiment_config
[
'trial'
].
get
(
'authFile'
):
...
...
@@ -65,6 +67,8 @@ def parse_path(experiment_config, config_path):
root_path
=
os
.
path
.
dirname
(
config_path
)
if
experiment_config
.
get
(
'searchSpacePath'
):
parse_relative_path
(
root_path
,
experiment_config
,
'searchSpacePath'
)
if
experiment_config
.
get
(
'logDir'
):
parse_relative_path
(
root_path
,
experiment_config
,
'logDir'
)
if
experiment_config
.
get
(
'trial'
):
parse_relative_path
(
root_path
,
experiment_config
[
'trial'
],
'codeDir'
)
if
experiment_config
[
'trial'
].
get
(
'authFile'
):
...
...
nni/tools/nnictl/nnictl_utils.py
View file @
b40e3db7
...
...
@@ -10,6 +10,7 @@ import re
import
shutil
import
subprocess
from
functools
import
cmp_to_key
import
traceback
from
datetime
import
datetime
,
timezone
from
subprocess
import
Popen
from
pyhdfs
import
HdfsClient
...
...
@@ -21,6 +22,7 @@ from .config_utils import Config, Experiments
from
.constants
import
NNICTL_HOME_DIR
,
NNI_HOME_DIR
,
EXPERIMENT_INFORMATION_FORMAT
,
EXPERIMENT_DETAIL_FORMAT
,
\
EXPERIMENT_MONITOR_INFO
,
TRIAL_MONITOR_HEAD
,
TRIAL_MONITOR_CONTENT
,
TRIAL_MONITOR_TAIL
,
REST_TIME_OUT
from
.common_utils
import
print_normal
,
print_error
,
print_warning
,
detect_process
,
get_yml_content
,
generate_temp_dir
from
.common_utils
import
print_green
from
.command_utils
import
check_output_command
,
kill_command
from
.ssh_utils
import
create_ssh_sftp_client
,
remove_remote_directory
...
...
@@ -28,7 +30,7 @@ def get_experiment_time(port):
'''get the startTime and endTime of an experiment'''
response
=
rest_get
(
experiment_url
(
port
),
REST_TIME_OUT
)
if
response
and
check_response
(
response
):
content
=
convert_time_stamp_to_date
(
json
.
loads
(
response
.
text
)
)
content
=
json
.
loads
(
response
.
text
)
return
content
.
get
(
'startTime'
),
content
.
get
(
'endTime'
)
return
None
,
None
...
...
@@ -48,20 +50,11 @@ def update_experiment():
for
key
in
experiment_dict
.
keys
():
if
isinstance
(
experiment_dict
[
key
],
dict
):
if
experiment_dict
[
key
].
get
(
'status'
)
!=
'STOPPED'
:
nni_config
=
Config
(
experiment_dict
[
key
][
'fileName'
]
)
nni_config
=
Config
(
key
)
rest_pid
=
nni_config
.
get_config
(
'restServerPid'
)
if
not
detect_process
(
rest_pid
):
experiment_config
.
update_experiment
(
key
,
'status'
,
'STOPPED'
)
continue
rest_port
=
nni_config
.
get_config
(
'restServerPort'
)
startTime
,
endTime
=
get_experiment_time
(
rest_port
)
if
startTime
:
experiment_config
.
update_experiment
(
key
,
'startTime'
,
startTime
)
if
endTime
:
experiment_config
.
update_experiment
(
key
,
'endTime'
,
endTime
)
status
=
get_experiment_status
(
rest_port
)
if
status
:
experiment_config
.
update_experiment
(
key
,
'status'
,
status
)
def
check_experiment_id
(
args
,
update
=
True
):
'''check if the id is valid
...
...
@@ -182,9 +175,7 @@ def get_config_filename(args):
if
experiment_id
is
None
:
print_error
(
'Please set correct experiment id.'
)
exit
(
1
)
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
return
experiment_dict
[
experiment_id
][
'fileName'
]
return
experiment_id
def
get_experiment_port
(
args
):
'''get the port of experiment'''
...
...
@@ -226,11 +217,9 @@ def stop_experiment(args):
exit
(
1
)
experiment_id_list
=
parse_ids
(
args
)
if
experiment_id_list
:
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
for
experiment_id
in
experiment_id_list
:
print_normal
(
'Stopping experiment %s'
%
experiment_id
)
nni_config
=
Config
(
experiment_d
ict
[
experiment_id
][
'fileName'
]
)
nni_config
=
Config
(
experiment_
i
d
)
rest_pid
=
nni_config
.
get_config
(
'restServerPid'
)
if
rest_pid
:
kill_command
(
rest_pid
)
...
...
@@ -243,9 +232,6 @@ def stop_experiment(args):
print_error
(
exception
)
nni_config
.
set_config
(
'tensorboardPidList'
,
[])
print_normal
(
'Stop experiment success.'
)
experiment_config
.
update_experiment
(
experiment_id
,
'status'
,
'STOPPED'
)
time_now
=
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
time
.
time
()))
experiment_config
.
update_experiment
(
experiment_id
,
'endTime'
,
str
(
time_now
))
def
trial_ls
(
args
):
'''List trial'''
...
...
@@ -372,6 +358,40 @@ def log_stderr(args):
'''get stderr log'''
log_internal
(
args
,
'stderr'
)
def
log_trial_adl_helper
(
args
,
experiment_id
):
# adljob_id format should be consistent to the one in "adlTrainingService.ts":
# const adlJobName: string = `nni-exp-${this.experimentId}-trial-${trialJobId}`.toLowerCase();
adlJobName
=
"nni-exp-{}-trial-{}"
.
format
(
experiment_id
,
args
.
trial_id
).
lower
()
print_warning
(
'Note that no log will show when trial is pending or done (succeeded or failed). '
'You can retry the command.'
)
print_green
(
'>>> Trial log streaming:'
)
try
:
subprocess
.
run
(
[
"kubectl"
,
"logs"
,
"-l"
,
"adaptdl/job=%s"
%
adlJobName
,
"-f"
# Follow the stream
],
# TODO: support remaining argument, uncomment the lines in nnictl.py
)
# TODO: emulate tee behaviors, not necessary tho.
except
KeyboardInterrupt
:
pass
except
Exception
:
print_error
(
'Error! Please check kubectl:'
)
traceback
.
print_exc
()
exit
(
1
)
finally
:
print_green
(
'<<< [adlJobName:%s]'
%
adlJobName
)
nni_manager_collection_path
=
os
.
path
.
expanduser
(
'~/nni-experiments/%s/trials/%s/stdout_log_collection.log'
%
(
experiment_id
,
args
.
trial_id
))
print_green
(
'>>> (Optional) How to persist the complete trial log locally:'
)
print
(
'Please ensure `logCollection: http` '
'exists in the experiment configuration yaml. '
'After trial done, you can check it from the file below:
\n
%s'
%
nni_manager_collection_path
)
def
log_trial
(
args
):
''''get trial log path'''
trial_id_path_dict
=
{}
...
...
@@ -388,16 +408,24 @@ def log_trial(args):
if
response
and
check_response
(
response
):
content
=
json
.
loads
(
response
.
text
)
for
trial
in
content
:
trial_id_list
.
append
(
trial
.
get
(
'
i
d'
))
trial_id_list
.
append
(
trial
.
get
(
'
trialJobI
d'
))
if
trial
.
get
(
'logPath'
):
trial_id_path_dict
[
trial
.
get
(
'
i
d'
)]
=
trial
[
'logPath'
]
trial_id_path_dict
[
trial
.
get
(
'
trialJobI
d'
)]
=
trial
[
'logPath'
]
else
:
print_error
(
'Restful server is not running...'
)
exit
(
1
)
is_adl
=
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'trainingServicePlatform'
)
==
'adl'
if
is_adl
and
not
args
.
trial_id
:
print_error
(
'Trial ID is required to retrieve the log for adl. Please specify it with "--trial_id".'
)
exit
(
1
)
if
args
.
trial_id
:
if
args
.
trial_id
not
in
trial_id_list
:
print_error
(
'Trial id {0} not correct, please check your command!'
.
format
(
args
.
trial_id
))
exit
(
1
)
if
is_adl
:
log_trial_adl_helper
(
args
,
nni_config
.
get_config
(
'experimentId'
))
# adl has its own way to log trial, and it thus returns right after the helper returns
return
if
trial_id_path_dict
.
get
(
args
.
trial_id
):
print_normal
(
'id:'
+
args
.
trial_id
+
' path:'
+
trial_id_path_dict
[
args
.
trial_id
])
else
:
...
...
@@ -429,7 +457,7 @@ def webui_nas(args):
if
sys
.
platform
==
'win32'
:
node_command
=
os
.
path
.
join
(
entry_dir
,
'node.exe'
)
else
:
node_command
=
'node'
node_command
=
os
.
path
.
join
(
entry_dir
,
'node'
)
cmds
=
[
node_command
,
'--max-old-space-size=4096'
,
entry_file
,
'--port'
,
str
(
args
.
port
),
'--logdir'
,
args
.
logdir
]
subprocess
.
run
(
cmds
,
cwd
=
entry_dir
)
except
KeyboardInterrupt
:
...
...
@@ -509,7 +537,7 @@ def experiment_clean(args):
else
:
break
for
experiment_id
in
experiment_id_list
:
nni_config
=
Config
(
experiment_d
ict
[
experiment_id
][
'fileName'
]
)
nni_config
=
Config
(
experiment_
i
d
)
platform
=
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'trainingServicePlatform'
)
experiment_id
=
nni_config
.
get_config
(
'experimentId'
)
if
platform
==
'remote'
:
...
...
@@ -624,18 +652,15 @@ def experiment_list(args):
experiment_dict
[
key
][
'status'
],
experiment_dict
[
key
][
'port'
],
experiment_dict
[
key
].
get
(
'platform'
),
experiment_dict
[
key
][
'startTime'
],
experiment_dict
[
key
][
'endTime'
])
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
experiment_dict
[
key
][
'startTime'
]
/
1000
))
if
isinstance
(
experiment_dict
[
key
][
'startTime'
],
int
)
else
experiment_dict
[
key
][
'startTime'
],
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
experiment_dict
[
key
][
'endTime'
]
/
1000
))
if
isinstance
(
experiment_dict
[
key
][
'endTime'
],
int
)
else
experiment_dict
[
key
][
'endTime'
])
print
(
EXPERIMENT_INFORMATION_FORMAT
%
experiment_information
)
return
experiment_id_list
def
get_time_interval
(
time1
,
time2
):
'''get the interval of two times'''
try
:
#convert time to timestamp
time1
=
time
.
mktime
(
time
.
strptime
(
time1
,
'%Y/%m/%d %H:%M:%S'
))
time2
=
time
.
mktime
(
time
.
strptime
(
time2
,
'%Y/%m/%d %H:%M:%S'
))
seconds
=
(
datetime
.
fromtimestamp
(
time2
)
-
datetime
.
fromtimestamp
(
time1
)).
seconds
seconds
=
int
((
time2
-
time1
)
/
1000
)
#convert seconds to day:hour:minute:second
days
=
seconds
/
86400
seconds
%=
86400
...
...
@@ -664,8 +689,8 @@ def show_experiment_info():
return
for
key
in
experiment_id_list
:
print
(
EXPERIMENT_MONITOR_INFO
%
(
key
,
experiment_dict
[
key
][
'status'
],
experiment_dict
[
key
][
'port'
],
\
experiment_dict
[
key
].
get
(
'platform'
),
experiment_dict
[
key
][
'startTime'
],
\
get_time_interval
(
experiment_dict
[
key
][
'startTime'
],
experiment_dict
[
key
][
'endTime'
])))
experiment_dict
[
key
].
get
(
'platform'
),
time
.
strftime
(
'%Y-%m-%d %H:%M:%S'
,
time
.
localtime
(
experiment_dict
[
key
][
'startTime'
]
/
1000
))
if
isinstance
(
experiment_dict
[
key
][
'startTime'
],
int
)
else
experiment_dict
[
key
][
'startTime'
],
\
get_time_interval
(
experiment_dict
[
key
][
'startTime'
],
experiment_dict
[
key
][
'endTime'
])))
print
(
TRIAL_MONITOR_HEAD
)
running
,
response
=
check_rest_server_quick
(
experiment_dict
[
key
][
'port'
])
if
running
:
...
...
@@ -674,7 +699,7 @@ def show_experiment_info():
content
=
json
.
loads
(
response
.
text
)
for
index
,
value
in
enumerate
(
content
):
content
[
index
]
=
convert_time_stamp_to_date
(
value
)
print
(
TRIAL_MONITOR_CONTENT
%
(
content
[
index
].
get
(
'
i
d'
),
content
[
index
].
get
(
'startTime'
),
\
print
(
TRIAL_MONITOR_CONTENT
%
(
content
[
index
].
get
(
'
trialJobI
d'
),
content
[
index
].
get
(
'startTime'
),
\
content
[
index
].
get
(
'endTime'
),
content
[
index
].
get
(
'status'
)))
print
(
TRIAL_MONITOR_TAIL
)
...
...
@@ -747,7 +772,7 @@ def export_trials_data(args):
return
intermediate_results
=
groupby_trial_id
(
json
.
loads
(
intermediate_results_response
.
text
))
for
record
in
content
:
record
[
'intermediate'
]
=
intermediate_results
[
record
[
'
i
d'
]]
record
[
'intermediate'
]
=
intermediate_results
[
record
[
'
trialJobI
d'
]]
if
args
.
type
==
'json'
:
with
open
(
args
.
path
,
'w'
)
as
file
:
file
.
write
(
json
.
dumps
(
content
))
...
...
@@ -759,9 +784,9 @@ def export_trials_data(args):
formated_record
[
'intermediate'
]
=
'['
+
','
.
join
(
record
[
'intermediate'
])
+
']'
record_value
=
json
.
loads
(
record
[
'value'
])
if
not
isinstance
(
record_value
,
(
float
,
int
)):
formated_record
.
update
({
**
record
[
'parameter'
],
**
record_value
,
**
{
'
i
d'
:
record
[
'
i
d'
]}})
formated_record
.
update
({
**
record
[
'parameter'
],
**
record_value
,
**
{
'
trialJobI
d'
:
record
[
'
trialJobI
d'
]}})
else
:
formated_record
.
update
({
**
record
[
'parameter'
],
**
{
'reward'
:
record_value
,
'
i
d'
:
record
[
'
i
d'
]}})
formated_record
.
update
({
**
record
[
'parameter'
],
**
{
'reward'
:
record_value
,
'
trialJobI
d'
:
record
[
'
trialJobI
d'
]}})
trial_records
.
append
(
formated_record
)
if
not
trial_records
:
print_error
(
'No trial results collected! Please check your trial log...'
)
...
...
@@ -806,7 +831,7 @@ def save_experiment(args):
print_error
(
'Can only save stopped experiment!'
)
exit
(
1
)
print_normal
(
'Saving...'
)
nni_config
=
Config
(
experiment_dict
[
args
.
id
][
'fileName'
]
)
nni_config
=
Config
(
args
.
id
)
logDir
=
os
.
path
.
join
(
NNI_HOME_DIR
,
args
.
id
)
if
nni_config
.
get_config
(
'logDir'
):
logDir
=
os
.
path
.
join
(
nni_config
.
get_config
(
'logDir'
),
args
.
id
)
...
...
@@ -829,8 +854,8 @@ def save_experiment(args):
except
IOError
:
print_error
(
'Write file to %s failed!'
%
os
.
path
.
join
(
temp_nnictl_dir
,
'.experiment'
))
exit
(
1
)
nnictl_config_dir
=
os
.
path
.
join
(
NNICTL_HOME_DIR
,
experiment_dict
[
args
.
id
][
'fileName'
]
)
shutil
.
copytree
(
nnictl_config_dir
,
os
.
path
.
join
(
temp_nnictl_dir
,
experiment_dict
[
args
.
id
][
'fileName'
]
))
nnictl_config_dir
=
os
.
path
.
join
(
NNICTL_HOME_DIR
,
args
.
id
)
shutil
.
copytree
(
nnictl_config_dir
,
os
.
path
.
join
(
temp_nnictl_dir
,
args
.
id
))
# Step3. Copy code dir
if
args
.
saveCodeDir
:
...
...
@@ -903,20 +928,20 @@ def load_experiment(args):
print_error
(
'Invalid: experiment id already exist!'
)
shutil
.
rmtree
(
temp_root_dir
)
exit
(
1
)
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
nnictl_temp_dir
,
experiment_
metadata
.
get
(
'fileName'
)
)):
if
not
os
.
path
.
exists
(
os
.
path
.
join
(
nnictl_temp_dir
,
experiment_
id
)):
print_error
(
'Invalid: experiment metadata does not exist!'
)
shutil
.
rmtree
(
temp_root_dir
)
exit
(
1
)
# Step2. Copy nnictl metadata
src_path
=
os
.
path
.
join
(
nnictl_temp_dir
,
experiment_
metadata
.
get
(
'fileName'
)
)
dest_path
=
os
.
path
.
join
(
NNICTL_HOME_DIR
,
experiment_
metadata
.
get
(
'fileName'
)
)
src_path
=
os
.
path
.
join
(
nnictl_temp_dir
,
experiment_
id
)
dest_path
=
os
.
path
.
join
(
NNICTL_HOME_DIR
,
experiment_
id
)
if
os
.
path
.
exists
(
dest_path
):
shutil
.
rmtree
(
dest_path
)
shutil
.
copytree
(
src_path
,
dest_path
)
# Step3. Copy experiment data
nni_config
=
Config
(
experiment_
metadata
.
get
(
'fileName'
)
)
nni_config
=
Config
(
experiment_
id
)
nnictl_exp_config
=
nni_config
.
get_config
(
'experimentConfig'
)
if
args
.
logDir
:
logDir
=
args
.
logDir
...
...
@@ -983,13 +1008,15 @@ def load_experiment(args):
experiment_config
.
add_experiment
(
experiment_id
,
experiment_metadata
.
get
(
'port'
),
experiment_metadata
.
get
(
'startTime'
),
experiment_metadata
.
get
(
'fileName'
),
experiment_metadata
.
get
(
'platform'
),
experiment_metadata
.
get
(
'experimentName'
),
experiment_metadata
.
get
(
'endTime'
),
experiment_metadata
.
get
(
'status'
))
experiment_metadata
.
get
(
'status'
),
experiment_metadata
.
get
(
'tag'
),
experiment_metadata
.
get
(
'pid'
),
experiment_metadata
.
get
(
'webUrl'
),
experiment_metadata
.
get
(
'logDir'
))
print_normal
(
'Load experiment %s succsss!'
%
experiment_id
)
# Step6. Cleanup temp data
shutil
.
rmtree
(
temp_root_dir
)
nni/tools/nnictl/tensorboard_utils.py
View file @
b40e3db7
...
...
@@ -10,8 +10,8 @@ from .rest_utils import rest_get, check_rest_server_quick, check_response
from
.config_utils
import
Config
,
Experiments
from
.url_utils
import
trial_jobs_url
,
get_local_urls
from
.constants
import
REST_TIME_OUT
from
.common_utils
import
print_normal
,
print_error
,
print_green
,
detect_process
,
detect_port
,
check_tensorboard_version
from
.nnictl_utils
import
check_experiment_id
,
check_experiment_id
from
.common_utils
import
print_normal
,
print_warning
,
print_error
,
print_green
,
detect_process
,
detect_port
,
check_tensorboard_version
from
.nnictl_utils
import
check_experiment_id
from
.ssh_utils
import
create_ssh_sftp_client
,
copy_remote_directory_to_local
def
parse_log_path
(
args
,
trial_content
):
...
...
@@ -19,7 +19,7 @@ def parse_log_path(args, trial_content):
path_list
=
[]
host_list
=
[]
for
trial
in
trial_content
:
if
args
.
trial_id
and
args
.
trial_id
!=
'all'
and
trial
.
get
(
'
i
d'
)
!=
args
.
trial_id
:
if
args
.
trial_id
and
args
.
trial_id
!=
'all'
and
trial
.
get
(
'
trialJobI
d'
)
!=
args
.
trial_id
:
continue
pattern
=
r
'(?P<head>.+)://(?P<host>.+):(?P<path>.*)'
match
=
re
.
search
(
pattern
,
trial
[
'logPath'
])
...
...
@@ -40,7 +40,7 @@ def copy_data_from_remote(args, nni_config, trial_content, path_list, host_list,
machine_dict
[
machine
[
'ip'
]]
=
{
'port'
:
machine
[
'port'
],
'passwd'
:
machine
[
'passwd'
],
'username'
:
machine
[
'username'
],
'sshKeyPath'
:
machine
.
get
(
'sshKeyPath'
),
'passphrase'
:
machine
.
get
(
'passphrase'
)}
for
index
,
host
in
enumerate
(
host_list
):
local_path
=
os
.
path
.
join
(
temp_nni_path
,
trial_content
[
index
].
get
(
'
i
d'
))
local_path
=
os
.
path
.
join
(
temp_nni_path
,
trial_content
[
index
].
get
(
'
trialJobI
d'
))
local_path_list
.
append
(
local_path
)
print_normal
(
'Copying log data from %s to %s'
%
(
host
+
':'
+
path_list
[
index
],
local_path
))
sftp
=
create_ssh_sftp_client
(
host
,
machine_dict
[
host
][
'port'
],
machine_dict
[
host
][
'username'
],
machine_dict
[
host
][
'passwd'
],
...
...
@@ -95,8 +95,7 @@ def stop_tensorboard(args):
experiment_id
=
check_experiment_id
(
args
)
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
config_file_name
=
experiment_dict
[
experiment_id
][
'fileName'
]
nni_config
=
Config
(
config_file_name
)
nni_config
=
Config
(
experiment_id
)
tensorboard_pid_list
=
nni_config
.
get_config
(
'tensorboardPidList'
)
if
tensorboard_pid_list
:
for
tensorboard_pid
in
tensorboard_pid_list
:
...
...
@@ -110,14 +109,36 @@ def stop_tensorboard(args):
else
:
print_error
(
'No tensorboard configuration!'
)
def
adl_tensorboard_helper
(
args
):
'''start tensorboard on adl'''
import
subprocess
if
args
.
trial_id
is
not
None
:
print_warning
(
'Tensorboard on adl platform will show all trials. No trial ids needed.'
)
cmd
=
"kubectl port-forward --address 0.0.0.0 deployment/{} {}:{}"
.
format
(
"adaptdl-tensorboard"
+
"-"
+
args
.
id
.
lower
(),
args
.
port
,
6006
)
print_green
(
'Tensorboard is accessible at 0.0.0.0:{port} or localhost:{port}'
.
format
(
port
=
args
.
port
))
subprocess
.
run
(
args
=
cmd
,
shell
=
True
)
def
start_tensorboard
(
args
):
'''start tensorboard'''
experiment_id
=
check_experiment_id
(
args
)
if
not
experiment_id
:
return
if
args
.
id
is
None
:
args
.
id
=
experiment_id
experiment_config
=
Experiments
()
experiment_dict
=
experiment_config
.
get_all_experiments
()
if
experiment_dict
[
args
.
id
][
"status"
]
==
"STOPPED"
:
print_error
(
"Experiment {} is stopped..."
.
format
(
args
.
id
))
return
config_file_name
=
experiment_dict
[
experiment_id
][
'fileName'
]
nni_config
=
Config
(
config_file_name
)
nni_config
=
Config
(
args
.
id
)
if
nni_config
.
get_config
(
'experimentConfig'
).
get
(
'trainingServicePlatform'
)
==
'adl'
:
adl_tensorboard_helper
(
args
)
return
rest_port
=
nni_config
.
get_config
(
'restServerPort'
)
rest_pid
=
nni_config
.
get_config
(
'restServerPid'
)
if
not
detect_process
(
rest_pid
):
...
...
@@ -144,4 +165,4 @@ def start_tensorboard(args):
os
.
makedirs
(
temp_nni_path
,
exist_ok
=
True
)
path_list
=
get_path_list
(
args
,
nni_config
,
trial_content
,
temp_nni_path
)
start_tensorboard_process
(
args
,
nni_config
,
path_list
,
temp_nni_path
)
start_tensorboard_process
(
args
,
nni_config
,
path_list
,
temp_nni_path
)
\ No newline at end of file
Prev
1
…
3
4
5
6
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment