OpenDAS / nni · Commits

Commit d165905d (unverified)
Authored Dec 11, 2020 by QuanluZhang; committed by GitHub on Dec 11, 2020
[Retiarii] end2end (#3122)
Parent: 7d1acfbd

Changes: 54 files in total.
Showing 20 changed files on this page, with 361 additions and 116 deletions.
File                                        Additions  Deletions
nni/retiarii/graph.py                             +14         -2
nni/retiarii/integration.py                        +2         -5
nni/retiarii/mutator.py                            +2         -2
nni/retiarii/nn/__init__.py                        +0         -0
nni/retiarii/nn/pytorch/__init__.py                +1         -0
nni/retiarii/nn/pytorch/nn.py                     +56         -7
nni/retiarii/operation.py                         +25         -0
nni/retiarii/strategies/__init__.py                +1         -0
nni/retiarii/strategies/strategy.py                +8         -0
nni/retiarii/strategies/tpe_strategy.py           +74         -0
nni/retiarii/trainer/interface.py                 +19         -0
nni/retiarii/trainer/pytorch/base.py               +3         -0
nni/retiarii/utils.py                             +18         -0
nni/runtime/common.py                              +0        -89
nni/runtime/log.py                               +134         -0
nni/runtime/msg_dispatcher_base.py                 +0         -2
nni/runtime/platform/local.py                      +0         -4
nni/runtime/platform/standalone.py                 +0         -3
nni/runtime/protocol.py                            +1         -2
nni/tools/nnictl/launcher.py                       +3         -0
nni/retiarii/graph.py

@@ -339,13 +339,21 @@ class Graph:
         while curr_nodes:
             curr_node = curr_nodes.pop(0)
             sorted_nodes.append(curr_node)
-            for successor in curr_node.successors:
+            # use successor_slots because a node may connect to another node multiple times
+            # to different slots
+            for successor_slot in curr_node.successor_slots:
+                successor = successor_slot[0]
                 node_to_fanin[successor] -= 1
                 if node_to_fanin[successor] == 0:
                     curr_nodes.append(successor)
         for key in node_to_fanin:
-            assert node_to_fanin[key] == 0
+            assert node_to_fanin[key] == 0, '{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'.format(
+                key, node_to_fanin[key], key.predecessors[0], self.edges,
+                node_to_fanin.values(), node_to_fanin.keys())
         return sorted_nodes

@@ -485,6 +493,10 @@ class Node:
     def successors(self) -> List['Node']:
         return sorted(set(edge.tail for edge in self.outgoing_edges), key=(lambda node: node.id))

+    @property
+    def successor_slots(self) -> List[Tuple['Node', Union[int, None]]]:
+        return set((edge.tail, edge.tail_slot) for edge in self.outgoing_edges)
+
     @property
     def incoming_edges(self) -> List['Edge']:
         return [edge for edge in self.graph.edges if edge.tail is self]
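The new loop is a slot-aware variant of Kahn's topological sort: fan-in is counted once per (successor, slot) pair, so a node that receives two edges from the same predecessor into different slots needs both to be drained before it is emitted. Below is a minimal self-contained sketch of the same idea; the (head, tail, tail_slot) edge triples are illustrative stand-ins, not NNI's Edge API.

    # Standalone sketch of a slot-aware topological sort (hypothetical data shapes).
    from collections import defaultdict

    def topo_sort(nodes, edges):
        succ_slots = defaultdict(set)
        for head, tail, tail_slot in edges:
            succ_slots[head].add((tail, tail_slot))
        # fan-in counts one unit per (tail, slot) pair, not per distinct successor
        fanin = {n: 0 for n in nodes}
        for head in succ_slots:
            for tail, _slot in succ_slots[head]:
                fanin[tail] += 1
        queue = [n for n in nodes if fanin[n] == 0]
        order = []
        while queue:
            node = queue.pop(0)
            order.append(node)
            for tail, _slot in succ_slots[node]:
                fanin[tail] -= 1
                if fanin[tail] == 0:
                    queue.append(tail)
        return order

    # 'a' feeds 'b' twice, into slots 0 and 1; both must be counted
    print(topo_sort(['a', 'b', 'c'], [('a', 'b', 0), ('a', 'b', 1), ('b', 'c', 0)]))
    # -> ['a', 'b', 'c']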
nni/retiarii/integration.py

@@ -44,7 +44,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
         final_metric_callback
     """

-    def __init__(self, strategy: Union[str, Callable]):
+    def __init__(self):
         super(RetiariiAdvisor, self).__init__()
         register_advisor(self)  # register the current advisor as the "global only" advisor
         self.search_space = None

@@ -55,11 +55,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
         self.intermediate_metric_callback: Callable[[int, MetricData], None] = None
         self.final_metric_callback: Callable[[int, MetricData], None] = None

-        self.strategy = utils.import_(strategy) if isinstance(strategy, str) else strategy
         self.parameters_count = 0

-        _logger.info('Starting strategy...')
-        threading.Thread(target=self.strategy).start()
-        _logger.info('Strategy started!')

     def handle_initialize(self, data):
         """callback for initializing the advisor

@@ -125,6 +121,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
     @staticmethod
     def _process_value(value) -> Any:  # hopefully a float
+        value = json_tricks.loads(value)
         if isinstance(value, dict):
             if 'default' in value:
                 return value['default']
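For reference, the dict branch shown above means a reported metric may now arrive as a json_tricks payload carrying a 'default' key. A hypothetical check of that branch (the non-dict fallthrough is elided in the hunk, so it is not exercised here):

    import json_tricks
    from nni.retiarii.integration import RetiariiAdvisor

    # dict payloads resolve to their 'default' entry
    assert RetiariiAdvisor._process_value(json_tricks.dumps({'default': 0.93})) == 0.93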
nni/retiarii/mutator.py

@@ -73,9 +73,9 @@ class Mutator:
         sampler_backup = self.sampler
         recorder = _RecorderSampler()
         self.sampler = recorder
-        self.apply(model)
+        new_model = self.apply(model)
         self.sampler = sampler_backup
-        return recorder.recorded_candidates
+        return recorder.recorded_candidates, new_model

     def mutate(self, model: Model) -> None:
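dry_run now returns the mutated model alongside the recorded candidates, so chained mutators can dry-run against each other's output rather than the original model. The consuming pattern, exactly as tpe_strategy.py below uses it (base_model and applied_mutators are placeholders):

    sample_space = []
    model = base_model
    for mutator in applied_mutators:
        recorded_candidates, model = mutator.dry_run(model)  # feed the mutated model forward
        sample_space.extend(recorded_candidates)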
nni/retiarii/model_apis/__init__.py → nni/retiarii/nn/__init__.py
File moved.
nni/retiarii/nn/pytorch/__init__.py (new file)

from .nn import *
nni/retiarii/model_apis/nn.py → nni/retiarii/nn/pytorch/nn.py

import inspect
import logging

import torch
import torch.nn as nn
from typing import (Any, Tuple, List, Optional)

_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
#consoleHandler = logging.StreamHandler()
#consoleHandler.setLevel(logging.INFO)
#_logger.addHandler(consoleHandler)

_records = None

def enable_record_args():
    ...

@@ -26,6 +23,42 @@ def get_records():
     global _records
     return _records

+def add_record(name, value):
+    global _records
+    if _records is not None:
+        assert name not in _records, '{} already in _records'.format(name)
+        _records[name] = value
+
+class LayerChoice(nn.Module):
+    def __init__(self, candidate_ops: List, label: str = None):
+        super(LayerChoice, self).__init__()
+        self.candidate_ops = candidate_ops
+        self.label = label
+
+    def forward(self, x):
+        return x
+
+class InputChoice(nn.Module):
+    def __init__(self, n_chosen: int = 1, reduction: str = 'sum', label: str = None):
+        super(InputChoice, self).__init__()
+        self.n_chosen = n_chosen
+        self.reduction = reduction
+        self.label = label
+
+    def forward(self, candidate_inputs: List['Tensor']) -> 'Tensor':
+        # fake return
+        return torch.tensor(candidate_inputs)
+
+class ValueChoice:
+    """
+    The instance of this class can only be used as input argument,
+    when instantiating a pytorch module.
+    TODO: can also be used in training approach
+    """
+    def __init__(self, candidate_values: List[Any]):
+        self.candidate_values = candidate_values

 class Placeholder(nn.Module):
     def __init__(self, label, related_info):
         ...

@@ -45,8 +78,13 @@ class Module(nn.Module):
         # TODO: users have to pass init's arguments to super init's arguments
         global _records
         if _records is not None:
-            # TODO: change tuple to dict
-            _records[id(self)] = (args, kwargs)
+            assert not kwargs
+            argname_list = list(inspect.signature(self.__class__).parameters.keys())
+            assert len(argname_list) == len(args), 'Error: {} not put input arguments in its super().__init__ function'.format(self.__class__)
+            full_args = {}
+            for i, arg_value in enumerate(args):
+                full_args[argname_list[i]] = args[i]
+            _records[id(self)] = full_args
         #print('my module: ', id(self), args, kwargs)
         super(Module, self).__init__()

@@ -57,6 +95,13 @@ class Sequential(nn.Sequential):
             _records[id(self)] = {}  # no args need to be recorded
         super(Sequential, self).__init__(*args)

+class ModuleList(nn.ModuleList):
+    def __init__(self, *args):
+        global _records
+        if _records is not None:
+            _records[id(self)] = {}  # no args need to be recorded
+        super(ModuleList, self).__init__(*args)

 def wrap_module(original_class):
     orig_init = original_class.__init__
     argname_list = list(inspect.signature(original_class).parameters.keys())
     ...

@@ -80,3 +125,7 @@ BatchNorm2d = wrap_module(nn.BatchNorm2d)
 ReLU = wrap_module(nn.ReLU)
 Dropout = wrap_module(nn.Dropout)
 Linear = wrap_module(nn.Linear)
+MaxPool2d = wrap_module(nn.MaxPool2d)
+AvgPool2d = wrap_module(nn.AvgPool2d)
+Identity = wrap_module(nn.Identity)
+AdaptiveAvgPool2d = wrap_module(nn.AdaptiveAvgPool2d)
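The Module.__init__ change replaces the raw (args, kwargs) tuple with a dict keyed by parameter name, recovered via inspect.signature. A minimal standalone sketch of that recording idea; the class names here are hypothetical, and this is not the wrap_module implementation (whose body is elided above):

    import inspect

    _records = {}  # illustrative global store, keyed by object id

    class Recorded:
        def __init__(self, *args):
            # inspect.signature(cls) yields __init__'s parameters without self,
            # so positional args can be mapped back to their names
            argnames = list(inspect.signature(self.__class__).parameters.keys())
            _records[id(self)] = {name: value for name, value in zip(argnames, args)}

    class Conv(Recorded):  # hypothetical module
        def __init__(self, in_ch, out_ch):
            super().__init__(in_ch, out_ch)  # forward all ctor args for recording

    c = Conv(3, 16)
    print(_records[id(c)])  # {'in_ch': 3, 'out_ch': 16}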
nni/retiarii/operation.py

@@ -105,6 +105,7 @@ class PyTorchOperation(Operation):
             return None

     def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str:
+        from .converter.op_types import Type
         if self._to_class_name() is not None:
             return f'{output} = self.{field}({", ".join(inputs)})'
         elif self.type.startswith('Function.'):
             ...

@@ -120,10 +121,34 @@ class PyTorchOperation(Operation):
             return f'{output} = [{", ".join(inputs)}]'
         elif self.type == 'aten::mean':
             return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
         elif self.type == 'aten::__getitem__':
             assert len(inputs) == 2
             return f'{output} = {inputs[0]}[{inputs[1]}]'
         elif self.type == 'aten::append':
             assert len(inputs) == 2
             return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
         elif self.type == 'aten::cat':
             assert len(inputs) == 2
             return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
         elif self.type == 'aten::add':
             assert len(inputs) == 2
             return f'{output} = {inputs[0]} + {inputs[1]}'
         elif self.type == Type.MergedSlice:
             assert (len(inputs) - 1) % 4 == 0
             slices = []
             dim = int((len(inputs) - 1) / 4)
             for i in range(dim):
                 slices.append(f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}')
             slice_str = ','.join(slices)
             return f'{output} = {inputs[0]}[{slice_str}]'
         elif self.type == 'aten::size':
             assert len(inputs) == 2
             return f'{output} = {inputs[0]}.size({inputs[1]})'
         elif self.type == 'aten::view':
             assert len(inputs) == 2
             return f'{output} = {inputs[0]}.view({inputs[1]})'
         elif self.type == 'aten::slice':
             raise RuntimeError('not supposed to have aten::slice operation')
         else:
             raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
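to_forward_code is string-template code generation: each traced operator type maps to one emitted line of forward code. A toy generator in the same style, covering only the two cases that mirror the diff; the free function and its op table are illustrative, not NNI's class:

    # Hypothetical mini code generator mirroring the aten::cat / aten::add branches.
    def to_forward_code(op_type, output, inputs):
        if op_type == 'aten::cat':
            assert len(inputs) == 2
            return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
        if op_type == 'aten::add':
            assert len(inputs) == 2
            return f'{output} = {inputs[0]} + {inputs[1]}'
        raise RuntimeError(f'unsupported operation type: {op_type}')

    print(to_forward_code('aten::cat', 'out1', ['tensors', '1']))
    # out1 = torch.cat(tensors, dim=1)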
nni/retiarii/strategies/__init__.py (new file)

from .tpe_strategy import TPEStrategy
nni/retiarii/strategies/strategy.py (new file)

import abc
from typing import List


class BaseStrategy(abc.ABC):

    @abc.abstractmethod
    def run(self, base_model: 'Model', applied_mutators: List['Mutator'], trainer: 'BaseTrainer') -> None:
        pass
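Any search strategy plugs in by subclassing BaseStrategy and implementing run. A hedged sketch of a random-search strategy against this interface, using only the calls visible elsewhere in this commit (bind_sampler, apply, apply_trainer, submit_models, wait_models); RandomSampler and the fixed budget are illustrative:

    import random

    from .. import Sampler, submit_models, wait_models
    from .strategy import BaseStrategy


    class RandomSampler(Sampler):
        def choice(self, candidates, mutator, model, index):
            return random.choice(candidates)


    class RandomStrategy(BaseStrategy):
        def run(self, base_model, applied_mutators, trainer):
            sampler = RandomSampler()
            for _ in range(10):  # fixed budget, for the sketch only
                model = base_model
                for mutator in applied_mutators:
                    mutator.bind_sampler(sampler)
                    model = mutator.apply(model)
                model.apply_trainer(trainer['modulename'], trainer['args'])
                submit_models(model)
                wait_models(model)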
nni/retiarii/strategies/tpe_strategy.py (new file)

import json
import logging
import random
import os

from .. import Model, submit_models, wait_models
from .. import Sampler
from .strategy import BaseStrategy
from ...algorithms.hpo.hyperopt_tuner.hyperopt_tuner import HyperoptTuner

_logger = logging.getLogger(__name__)


class TPESampler(Sampler):
    def __init__(self, optimize_mode='minimize'):
        self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
        self.cur_sample = None
        self.index = None
        self.total_parameters = {}

    def update_sample_space(self, sample_space):
        search_space = {}
        for i, each in enumerate(sample_space):
            search_space[str(i)] = {'_type': 'choice', '_value': each}
        self.tpe_tuner.update_search_space(search_space)

    def generate_samples(self, model_id):
        self.cur_sample = self.tpe_tuner.generate_parameters(model_id)
        self.total_parameters[model_id] = self.cur_sample
        self.index = 0

    def receive_result(self, model_id, result):
        self.tpe_tuner.receive_trial_result(model_id, self.total_parameters[model_id], result)

    def choice(self, candidates, mutator, model, index):
        chosen = self.cur_sample[str(self.index)]
        self.index += 1
        return chosen


class TPEStrategy(BaseStrategy):
    def __init__(self):
        self.tpe_sampler = TPESampler()
        self.model_id = 0

    def run(self, base_model, applied_mutators, trainer):
        sample_space = []
        new_model = base_model
        for mutator in applied_mutators:
            recorded_candidates, new_model = mutator.dry_run(new_model)
            sample_space.extend(recorded_candidates)
        self.tpe_sampler.update_sample_space(sample_space)

        try:
            _logger.info('strategy start...')
            while True:
                model = base_model
                _logger.info('apply mutators...')
                _logger.info('mutators: {}'.format(applied_mutators))
                self.tpe_sampler.generate_samples(self.model_id)
                for mutator in applied_mutators:
                    _logger.info('mutate model...')
                    mutator.bind_sampler(self.tpe_sampler)
                    model = mutator.apply(model)
                # get and apply training approach
                _logger.info('apply training approach...')
                model.apply_trainer(trainer['modulename'], trainer['args'])
                # run models
                submit_models(model)
                wait_models(model)
                self.tpe_sampler.receive_result(self.model_id, model.metric)
                self.model_id += 1
                _logger.info('Strategy says: %s', model.metric)
        except Exception as e:
            _logger.error(logging.exception('message'))
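Hypothetical wiring of TPEStrategy; base_model and applied_mutators are placeholders, and the trainer spec follows the {'modulename', 'args'} shape that apply_trainer and BaseTrainer.__init__ agree on in this commit:

    from nni.retiarii.strategies import TPEStrategy

    strategy = TPEStrategy()
    # 'my_package.MyTrainer' and its args are illustrative, not an NNI-provided trainer
    strategy.run(base_model, applied_mutators,
                 {'modulename': 'my_package.MyTrainer', 'args': {'learning_rate': 0.1}})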
nni/retiarii/trainer/interface.py

import abc
import inspect

from ..nn.pytorch import add_record
from typing import *

@@ -17,6 +19,23 @@ class BaseTrainer(abc.ABC):
     Trainer has a ``fit`` function with no return value. Intermediate results and final results should be
     directly sent via ``nni.report_intermediate_result()`` and ``nni.report_final_result()`` functions.
     """

+    def __init__(self, *args, **kwargs):
+        module = self.__class__.__module__
+        if module is None or module == str.__class__.__module__:
+            full_class_name = self.__class__.__name__
+        else:
+            full_class_name = module + '.' + self.__class__.__name__
+        assert not kwargs
+        argname_list = list(inspect.signature(self.__class__).parameters.keys())
+        assert len(argname_list) == len(args), 'Error: {} not put input arguments in its super().__init__ function'.format(self.__class__)
+        full_args = {}
+        for i, arg_value in enumerate(args):
+            if argname_list[i] == 'model':
+                assert i == 0
+                continue
+            full_args[argname_list[i]] = args[i]
+        add_record(id(self), {'modulename': full_class_name, 'args': full_args})

     @abc.abstractmethod
     def fit(self) -> None:
         ...
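The contract this __init__ enforces: a concrete trainer must forward all of its positional constructor arguments to super().__init__ so they can be recorded (the 'model' argument is skipped and must come first). A sketch of a conforming subclass, assuming fit is the only abstract method; MyTrainer is hypothetical:

    from nni.retiarii.trainer.interface import BaseTrainer


    class MyTrainer(BaseTrainer):  # hypothetical trainer
        def __init__(self, model, learning_rate):
            # every positional argument is forwarded, or the assert above fires
            super().__init__(model, learning_rate)
            self.model = model
            self.learning_rate = learning_rate

        def fit(self) -> None:
            pass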
nni/retiarii/trainer/pytorch/base.py

@@ -78,6 +78,9 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
         Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
         only the key ``max_epochs`` is useful.
         """
+        super(PyTorchImageClassificationTrainer, self).__init__(model, dataset_cls, dataset_kwargs,
+                                                                dataloader_kwargs, optimizer_cls,
+                                                                optimizer_kwargs, trainer_kwargs)
         self._use_cuda = torch.cuda.is_available()
         self.model = model
         if self._use_cuda:
             ...
nni/retiarii/utils.py

import traceback

from .nn.pytorch import enable_record_args, get_records, disable_record_args


def import_(target: str, allow_none: bool = False) -> 'Any':
    if target is None:
        return None
    path, identifier = target.rsplit('.', 1)
    module = __import__(path, globals(), locals(), [identifier])
    return getattr(module, identifier)


class TraceClassArguments:
    def __init__(self):
        self.recorded_arguments = None

    def __enter__(self):
        enable_record_args()
        return self

    def __exit__(self, exc_type, exc_value, tb):
        if exc_type is not None:
            traceback.print_exception(exc_type, exc_value, tb)
            # return False  # uncomment to pass exception through
        self.recorded_arguments = get_records()
        disable_record_args()
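A sketch of the intended use of TraceClassArguments: modules built from the recording nn wrappers inside the with block land in recorded_arguments, keyed by object id. Net is a hypothetical model class, and the printed shape is an assumption based on the recording code above:

    import nni.retiarii.nn.pytorch as nn
    from nni.retiarii.utils import TraceClassArguments


    class Net(nn.Module):  # hypothetical model
        def __init__(self, hidden_size):
            super().__init__(hidden_size)  # forwarded so it can be recorded
            self.fc = nn.Linear(hidden_size, 10)


    with TraceClassArguments() as tca:
        net = Net(128)

    print(tca.recorded_arguments)  # assumed: {id(obj): {...ctor args...}, ...}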
nni/runtime/common.py

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from datetime import datetime
-from io import TextIOBase
-import logging
-import os
-import sys
-import time
-
-log_level_map = {
-    'fatal': logging.FATAL,
-    'error': logging.ERROR,
-    'warning': logging.WARNING,
-    'info': logging.INFO,
-    'debug': logging.DEBUG
-}
-
-_time_format = '%m/%d/%Y, %I:%M:%S %p'
-
-# FIXME
-# This hotfix the bug that querying installed tuners with `package_utils` will activate dispatcher logger.
-# This behavior depends on underlying implementation of `nnictl` and is likely to break in future.
-_logger_initialized = False
-
-class _LoggerFileWrapper(TextIOBase):
-    def __init__(self, logger_file):
-        self.file = logger_file
-
-    def write(self, s):
-        if s != '\n':
-            cur_time = datetime.now().strftime(_time_format)
-            self.file.write('[{}] PRINT '.format(cur_time) + s + '\n')
-            self.file.flush()
-        return len(s)
-
-def init_logger(logger_file_path, log_level_name='info'):
-    """Initialize root logger.
-    This will redirect anything from logging.getLogger() as well as stdout to specified file.
-    logger_file_path: path of logger file (path-like object).
-    """
-    global _logger_initialized
-    if _logger_initialized:
-        return
-    _logger_initialized = True
-
-    if os.environ.get('NNI_PLATFORM') == 'unittest':
-        return  # fixme: launching logic needs refactor
-
-    log_level = log_level_map.get(log_level_name, logging.INFO)
-    logger_file = open(logger_file_path, 'w')
-    fmt = '[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
-    logging.Formatter.converter = time.localtime
-    formatter = logging.Formatter(fmt, _time_format)
-    handler = logging.StreamHandler(logger_file)
-    handler.setFormatter(formatter)
-
-    root_logger = logging.getLogger()
-    root_logger.addHandler(handler)
-    root_logger.setLevel(log_level)
-
-    # these modules are too verbose
-    logging.getLogger('matplotlib').setLevel(log_level)
-
-    sys.stdout = _LoggerFileWrapper(logger_file)
-
-def init_standalone_logger():
-    """
-    Initialize root logger for standalone mode.
-    This will set NNI's log level to INFO and print its log to stdout.
-    """
-    global _logger_initialized
-    if _logger_initialized:
-        return
-    _logger_initialized = True
-
-    fmt = '[%(asctime)s] %(levelname)s (%(name)s) %(message)s'
-    formatter = logging.Formatter(fmt, _time_format)
-    handler = logging.StreamHandler(sys.stdout)
-    handler.setFormatter(formatter)
-    nni_logger = logging.getLogger('nni')
-    nni_logger.addHandler(handler)
-    nni_logger.setLevel(logging.INFO)
-    nni_logger.propagate = False
-
-    # Following line does not affect NNI loggers, but without this user's logger won't be able to
-    # print log even it's level is set to INFO, so we do it for user's convenience.
-    # If this causes any issue in future, remove it and use `logging.info` instead of
-    # `logging.getLogger('xxx')` in all examples.
-    logging.basicConfig()

 _multi_thread = False
 _multi_phase = False
nni/runtime/log.py (new file)

from datetime import datetime
from io import TextIOBase
import logging
from logging import FileHandler, Formatter, Handler, StreamHandler
from pathlib import Path
import sys
from typing import Optional

from .env_vars import dispatcher_env_vars, trial_env_vars


def init_logger() -> None:
    """
    This function will (and should only) get invoked on the first time of importing nni (no matter which submodule).
    It will try to detect the running environment and setup logger accordingly.
    The detection should work in most cases but for `nnictl` and `nni.experiment`.
    They will be identified as "standalone" mode and must configure the logger by themselves.
    """
    if dispatcher_env_vars.SDK_PROCESS == 'dispatcher':
        _init_logger_dispatcher()
        return

    trial_platform = trial_env_vars.NNI_PLATFORM

    if trial_platform == 'unittest':
        return

    if trial_platform:
        _init_logger_trial()
        return

    _init_logger_standalone()


time_format = '%Y-%m-%d %H:%M:%S'

formatter = Formatter('[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s', time_format)


def _init_logger_dispatcher() -> None:
    log_level_map = {
        'fatal': logging.CRITICAL,
        'error': logging.ERROR,
        'warning': logging.WARNING,
        'info': logging.INFO,
        'debug': logging.DEBUG
    }

    log_path = _prepare_log_dir(dispatcher_env_vars.NNI_LOG_DIRECTORY) / 'dispatcher.log'
    log_level = log_level_map.get(dispatcher_env_vars.NNI_LOG_LEVEL, logging.INFO)
    _setup_root_logger(FileHandler(log_path), log_level)


def _init_logger_trial() -> None:
    log_path = _prepare_log_dir(trial_env_vars.NNI_OUTPUT_DIR) / 'trial.log'
    log_file = open(log_path, 'w')
    _setup_root_logger(StreamHandler(log_file), logging.INFO)

    sys.stdout = _LogFileWrapper(log_file)


def _init_logger_standalone() -> None:
    _setup_nni_logger(StreamHandler(sys.stdout), logging.INFO)

    # Following line does not affect NNI loggers, but without this user's logger won't
    # print log even it's level is set to INFO, so we do it for user's convenience.
    # If this causes any issue in future, remove it and use `logging.info()` instead of
    # `logging.getLogger('xxx').info()` in all examples.
    logging.basicConfig()


def _prepare_log_dir(path: Optional[str]) -> Path:
    if path is None:
        return Path()
    ret = Path(path)
    ret.mkdir(parents=True, exist_ok=True)
    return ret


def _setup_root_logger(handler: Handler, level: int) -> None:
    _setup_logger('', handler, level)


def _setup_nni_logger(handler: Handler, level: int) -> None:
    _setup_logger('nni', handler, level)


def _setup_logger(name: str, handler: Handler, level: int) -> None:
    handler.setFormatter(formatter)
    logger = logging.getLogger(name)
    logger.addHandler(handler)
    logger.setLevel(level)
    logger.propagate = False


class _LogFileWrapper(TextIOBase):
    # wrap the logger file so that anything written to it will automatically get formatted

    def __init__(self, log_file: TextIOBase):
        self.file: TextIOBase = log_file
        self.line_buffer: Optional[str] = None
        self.line_start_time: Optional[datetime] = None

    def write(self, s: str) -> int:
        cur_time = datetime.now()
        if self.line_buffer and (cur_time - self.line_start_time).total_seconds() > 0.1:
            self.flush()

        if self.line_buffer:
            self.line_buffer += s
        else:
            self.line_buffer = s
            self.line_start_time = cur_time

        if '\n' not in s:
            return len(s)

        time_str = cur_time.strftime(time_format)
        lines = self.line_buffer.split('\n')
        for line in lines[:-1]:
            self.file.write(f'[{time_str}] PRINT {line}\n')
        self.file.flush()

        self.line_buffer = lines[-1]
        self.line_start_time = cur_time
        return len(s)

    def flush(self) -> None:
        if self.line_buffer:
            time_str = self.line_start_time.strftime(time_format)
            self.file.write(f'[{time_str}] PRINT {self.line_buffer}\n')
            self.file.flush()
            self.line_buffer = None
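A small demonstration of the _LogFileWrapper buffering contract: partial writes are held until a newline (or a 0.1 s timeout on the next write) arrives, then emitted as '[time] PRINT <line>' records. This sketch assumes an installed nni that includes this module; the StringIO target stands in for the trial log file:

    import io
    import sys
    from nni.runtime.log import _LogFileWrapper

    buf = io.StringIO()
    sys.stdout = _LogFileWrapper(buf)
    print('hello', end='')   # buffered: no newline yet
    print(' world')          # newline arrives: one PRINT record is written
    sys.stdout = sys.__stdout__
    print(buf.getvalue())    # e.g. '[2020-12-11 12:00:00] PRINT hello world'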
nni/runtime/msg_dispatcher_base.py

@@ -9,11 +9,9 @@ import json_tricks
 from .common import multi_thread_enabled
 from .env_vars import dispatcher_env_vars
-from ..utils import init_dispatcher_logger
 from ..recoverable import Recoverable
 from .protocol import CommandType, receive

-init_dispatcher_logger()

 _logger = logging.getLogger(__name__)
nni/runtime/platform/local.py

@@ -7,7 +7,6 @@ import json
 import time
 import subprocess

-from ..common import init_logger
 from ..env_vars import trial_env_vars
 from nni.utils import to_json

@@ -21,9 +20,6 @@ if not os.path.exists(_outputdir):
     os.makedirs(_outputdir)

 _nni_platform = trial_env_vars.NNI_PLATFORM
-if _nni_platform == 'local':
-    _log_file_path = os.path.join(_outputdir, 'trial.log')
-    init_logger(_log_file_path)

 _multiphase = trial_env_vars.MULTI_PHASE
nni/runtime/platform/standalone.py

@@ -4,8 +4,6 @@
 import logging
 import json_tricks

-from ..common import init_standalone_logger

 __all__ = [
     'get_next_parameter',
     'get_experiment_id',
     ...

@@ -14,7 +12,6 @@ __all__ = [
     'send_metric',
 ]

-init_standalone_logger()

 _logger = logging.getLogger('nni')
nni/runtime/protocol.py

@@ -32,8 +32,7 @@ try:
     _in_file = open(3, 'rb')
     _out_file = open(4, 'wb')
 except OSError:
-    _msg = 'IPC pipeline not exists, maybe you are importing tuner/assessor from trial code?'
-    logging.getLogger(__name__).warning(_msg)
+    pass

 def send(command, data):
     ...
nni/tools/nnictl/launcher.py

@@ -85,6 +85,7 @@ def start_rest_server(port, platform, mode, experiment_id, foreground=False, log
     log_header = LOG_HEADER % str(time_now)
     stdout_file.write(log_header)
     stderr_file.write(log_header)
+    print('## [nnictl] cmds:', cmds)
     if sys.platform == 'win32':
         from subprocess import CREATE_NEW_PROCESS_GROUP
         if foreground:
             ...

@@ -387,6 +388,8 @@ def set_experiment(experiment_config, mode, port, config_file_name):
             {'key': 'aml_config', 'value': experiment_config['amlConfig']})
     request_data['clusterMetaData'].append(
         {'key': 'trial_config', 'value': experiment_config['trial']})
+    print('## experiment config:')
+    print(request_data)
     response = rest_post(experiment_url(port), json.dumps(request_data), REST_TIME_OUT, show_error=True)
     if check_response(response):
         return response