Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
e457047c
Unverified
Commit
e457047c
authored
Mar 09, 2021
by
QuanluZhang
Committed by
GitHub
Mar 09, 2021
Browse files
[retiarii] update debug info, and add license (#3438)
parent
539a7cd7
Changes
42
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
72 additions
and
13 deletions
+72
-13
nni/retiarii/integration.py
nni/retiarii/integration.py
+8
-5
nni/retiarii/integration_api.py
nni/retiarii/integration_api.py
+3
-0
nni/retiarii/mutator.py
nni/retiarii/mutator.py
+3
-0
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+3
-0
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+3
-0
nni/retiarii/nn/pytorch/nn.py
nni/retiarii/nn/pytorch/nn.py
+3
-0
nni/retiarii/oneshot/__init__.py
nni/retiarii/oneshot/__init__.py
+3
-0
nni/retiarii/oneshot/interface.py
nni/retiarii/oneshot/interface.py
+3
-0
nni/retiarii/oneshot/pytorch/__init__.py
nni/retiarii/oneshot/pytorch/__init__.py
+3
-0
nni/retiarii/operation.py
nni/retiarii/operation.py
+3
-0
nni/retiarii/operation_def/tf_op_def.py
nni/retiarii/operation_def/tf_op_def.py
+3
-0
nni/retiarii/operation_def/torch_op_def.py
nni/retiarii/operation_def/torch_op_def.py
+3
-1
nni/retiarii/serializer.py
nni/retiarii/serializer.py
+3
-0
nni/retiarii/strategy/__init__.py
nni/retiarii/strategy/__init__.py
+3
-0
nni/retiarii/strategy/base.py
nni/retiarii/strategy/base.py
+3
-0
nni/retiarii/strategy/bruteforce.py
nni/retiarii/strategy/bruteforce.py
+7
-4
nni/retiarii/strategy/evolution.py
nni/retiarii/strategy/evolution.py
+5
-2
nni/retiarii/strategy/tpe_strategy.py
nni/retiarii/strategy/tpe_strategy.py
+4
-1
nni/retiarii/strategy/utils.py
nni/retiarii/strategy/utils.py
+3
-0
nni/retiarii/trial_entry.py
nni/retiarii/trial_entry.py
+3
-0
No files found.
nni/retiarii/integration.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
logging
import
os
from
typing
import
Any
,
Callable
...
...
@@ -99,7 +102,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
'parameters'
:
parameters
,
'parameter_source'
:
'algorithm'
}
_logger
.
info
(
'New trial sent: %s'
,
new_trial
)
_logger
.
debug
(
'New trial sent: %s'
,
new_trial
)
send
(
CommandType
.
NewTrialJob
,
json_dumps
(
new_trial
))
if
self
.
send_trial_callback
is
not
None
:
self
.
send_trial_callback
(
parameters
)
# pylint: disable=not-callable
...
...
@@ -109,21 +112,21 @@ class RetiariiAdvisor(MsgDispatcherBase):
send
(
CommandType
.
NoMoreTrialJobs
,
''
)
def
handle_request_trial_jobs
(
self
,
num_trials
):
_logger
.
info
(
'Request trial jobs: %s'
,
num_trials
)
_logger
.
debug
(
'Request trial jobs: %s'
,
num_trials
)
if
self
.
request_trial_jobs_callback
is
not
None
:
self
.
request_trial_jobs_callback
(
num_trials
)
# pylint: disable=not-callable
def
handle_update_search_space
(
self
,
data
):
_logger
.
info
(
'Received search space: %s'
,
data
)
_logger
.
debug
(
'Received search space: %s'
,
data
)
self
.
search_space
=
data
def
handle_trial_end
(
self
,
data
):
_logger
.
info
(
'Trial end: %s'
,
data
)
_logger
.
debug
(
'Trial end: %s'
,
data
)
self
.
trial_end_callback
(
json_loads
(
data
[
'hyper_params'
])[
'parameter_id'
],
# pylint: disable=not-callable
data
[
'event'
]
==
'SUCCEEDED'
)
def
handle_report_metric_data
(
self
,
data
):
_logger
.
info
(
'Metric reported: %s'
,
data
)
_logger
.
debug
(
'Metric reported: %s'
,
data
)
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
raise
ValueError
(
'Request parameter not supported'
)
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
...
...
nni/retiarii/integration_api.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
json
from
typing
import
NewType
,
Any
...
...
nni/retiarii/mutator.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
typing
import
(
Any
,
Iterable
,
List
,
Optional
)
from
.graph
import
Model
...
...
nni/retiarii/nn/pytorch/api.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
collections
import
OrderedDict
from
typing
import
Any
,
List
,
Union
,
Dict
import
warnings
...
...
nni/retiarii/nn/pytorch/mutator.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
typing
import
Any
,
List
,
Optional
,
Tuple
from
...mutator
import
Mutator
...
...
nni/retiarii/nn/pytorch/nn.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
torch
import
torch.nn
as
nn
...
...
nni/retiarii/oneshot/__init__.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.interface
import
BaseOneShotTrainer
nni/retiarii/oneshot/interface.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
abc
from
typing
import
Any
...
...
nni/retiarii/oneshot/pytorch/__init__.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.darts
import
DartsTrainer
from
.enas
import
EnasTrainer
from
.proxyless
import
ProxylessTrainer
...
...
nni/retiarii/operation.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
typing
import
(
Any
,
Dict
,
List
)
from
.
import
debug_configs
...
...
nni/retiarii/operation_def/tf_op_def.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
..operation
import
TensorFlowOperation
...
...
nni/retiarii/operation_def/torch_op_def.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
typing
import
(
Any
,
List
)
import
torch
...
...
@@ -369,7 +372,6 @@ class TensorOps(PyTorchOperation):
return
TensorOpExceptions
[
self
.
type
](
output
,
inputs
)
op_name
=
self
.
type
.
split
(
'::'
)[
-
1
]
args_str
=
', '
.
join
([
f
'
{
name
}
=
{
inputs
[
i
+
1
]
}
'
for
i
,
(
name
,
t
,
default
)
in
enumerate
(
matched_args
)])
print
(
args_str
)
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.
{
op_name
}
(
{
args_str
}
)'
class
TorchOps
(
PyTorchOperation
):
...
...
nni/retiarii/serializer.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
abc
import
functools
import
inspect
...
...
nni/retiarii/strategy/__init__.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.base
import
BaseStrategy
from
.bruteforce
import
Random
,
GridSearch
from
.evolution
import
RegularizedEvolution
...
...
nni/retiarii/strategy/base.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
abc
from
typing
import
List
...
...
nni/retiarii/strategy/bruteforce.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
copy
import
itertools
import
logging
...
...
@@ -36,7 +39,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
history
.
add
(
selected
)
break
if
retry_count
+
1
==
retries
:
_logger
.
info
(
'Random generation has run out of patience. There is nothing to search. Exiting.'
)
_logger
.
debug
(
'Random generation has run out of patience. There is nothing to search. Exiting.'
)
return
yield
{
key
:
value
for
key
,
value
in
zip
(
keys
,
selected
)}
...
...
@@ -58,7 +61,7 @@ class GridSearch(BaseStrategy):
def
run
(
self
,
base_model
,
applied_mutators
):
search_space
=
dry_run_for_search_space
(
base_model
,
applied_mutators
)
for
sample
in
grid_generator
(
search_space
,
shuffle
=
self
.
shuffle
):
_logger
.
info
(
'New model created. Waiting for resource. %s'
,
str
(
sample
))
_logger
.
debug
(
'New model created. Waiting for resource. %s'
,
str
(
sample
))
if
query_available_resources
()
<=
0
:
time
.
sleep
(
self
.
_polling_interval
)
submit_models
(
get_targeted_model
(
base_model
,
applied_mutators
,
sample
))
...
...
@@ -101,7 +104,7 @@ class Random(BaseStrategy):
model
=
base_model
for
mutator
in
applied_mutators
:
model
=
mutator
.
apply
(
model
)
_logger
.
info
(
'New model created. Applied mutators are: %s'
,
str
(
applied_mutators
))
_logger
.
debug
(
'New model created. Applied mutators are: %s'
,
str
(
applied_mutators
))
submit_models
(
model
)
else
:
time
.
sleep
(
self
.
_polling_interval
)
...
...
@@ -109,7 +112,7 @@ class Random(BaseStrategy):
_logger
.
info
(
'Random search running in fixed size mode. Dedup: %s.'
,
'on'
if
self
.
dedup
else
'off'
)
search_space
=
dry_run_for_search_space
(
base_model
,
applied_mutators
)
for
sample
in
random_generator
(
search_space
,
dedup
=
self
.
dedup
):
_logger
.
info
(
'New model created. Waiting for resource. %s'
,
str
(
sample
))
_logger
.
debug
(
'New model created. Waiting for resource. %s'
,
str
(
sample
))
if
query_available_resources
()
<=
0
:
time
.
sleep
(
self
.
_polling_interval
)
submit_models
(
get_targeted_model
(
base_model
,
applied_mutators
,
sample
))
nni/retiarii/strategy/evolution.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
collections
import
dataclasses
import
logging
...
...
@@ -122,7 +125,7 @@ class RegularizedEvolution(BaseStrategy):
break
def
_submit_config
(
self
,
config
,
base_model
,
mutators
):
_logger
.
info
(
'Model submitted to running queue: %s'
,
config
)
_logger
.
debug
(
'Model submitted to running queue: %s'
,
config
)
model
=
get_targeted_model
(
base_model
,
mutators
,
config
)
submit_models
(
model
)
self
.
_running_models
.
append
((
config
,
model
))
...
...
@@ -138,7 +141,7 @@ class RegularizedEvolution(BaseStrategy):
metric
=
model
.
metric
if
metric
is
not
None
:
individual
=
Individual
(
config
,
metric
)
_logger
.
info
(
'Individual created: %s'
,
str
(
individual
))
_logger
.
debug
(
'Individual created: %s'
,
str
(
individual
))
self
.
_population
.
append
(
individual
)
if
len
(
self
.
_population
)
>
self
.
population_size
:
self
.
_population
.
popleft
()
...
...
nni/retiarii/strategy/tpe_strategy.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
logging
import
time
...
...
@@ -55,7 +58,7 @@ class TPEStrategy(BaseStrategy):
avail_resource
=
query_available_resources
()
if
avail_resource
>
0
:
model
=
base_model
_logger
.
info
(
'New model created. Applied mutators: %s'
,
str
(
applied_mutators
))
_logger
.
debug
(
'New model created. Applied mutators: %s'
,
str
(
applied_mutators
))
self
.
tpe_sampler
.
generate_samples
(
self
.
model_id
)
for
mutator
in
applied_mutators
:
mutator
.
bind_sampler
(
self
.
tpe_sampler
)
...
...
nni/retiarii/strategy/utils.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
collections
from
typing
import
Dict
,
Any
,
List
from
..graph
import
Model
...
...
nni/retiarii/trial_entry.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Entrypoint for trials.
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment