Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
e457047c
Unverified
Commit
e457047c
authored
Mar 09, 2021
by
QuanluZhang
Committed by
GitHub
Mar 09, 2021
Browse files
[retiarii] update debug info, and add license (#3438)
parent
539a7cd7
Changes
42
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
72 additions
and
13 deletions
+72
-13
nni/retiarii/integration.py
nni/retiarii/integration.py
+8
-5
nni/retiarii/integration_api.py
nni/retiarii/integration_api.py
+3
-0
nni/retiarii/mutator.py
nni/retiarii/mutator.py
+3
-0
nni/retiarii/nn/pytorch/api.py
nni/retiarii/nn/pytorch/api.py
+3
-0
nni/retiarii/nn/pytorch/mutator.py
nni/retiarii/nn/pytorch/mutator.py
+3
-0
nni/retiarii/nn/pytorch/nn.py
nni/retiarii/nn/pytorch/nn.py
+3
-0
nni/retiarii/oneshot/__init__.py
nni/retiarii/oneshot/__init__.py
+3
-0
nni/retiarii/oneshot/interface.py
nni/retiarii/oneshot/interface.py
+3
-0
nni/retiarii/oneshot/pytorch/__init__.py
nni/retiarii/oneshot/pytorch/__init__.py
+3
-0
nni/retiarii/operation.py
nni/retiarii/operation.py
+3
-0
nni/retiarii/operation_def/tf_op_def.py
nni/retiarii/operation_def/tf_op_def.py
+3
-0
nni/retiarii/operation_def/torch_op_def.py
nni/retiarii/operation_def/torch_op_def.py
+3
-1
nni/retiarii/serializer.py
nni/retiarii/serializer.py
+3
-0
nni/retiarii/strategy/__init__.py
nni/retiarii/strategy/__init__.py
+3
-0
nni/retiarii/strategy/base.py
nni/retiarii/strategy/base.py
+3
-0
nni/retiarii/strategy/bruteforce.py
nni/retiarii/strategy/bruteforce.py
+7
-4
nni/retiarii/strategy/evolution.py
nni/retiarii/strategy/evolution.py
+5
-2
nni/retiarii/strategy/tpe_strategy.py
nni/retiarii/strategy/tpe_strategy.py
+4
-1
nni/retiarii/strategy/utils.py
nni/retiarii/strategy/utils.py
+3
-0
nni/retiarii/trial_entry.py
nni/retiarii/trial_entry.py
+3
-0
No files found.
nni/retiarii/integration.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
logging
import
logging
import
os
import
os
from
typing
import
Any
,
Callable
from
typing
import
Any
,
Callable
...
@@ -99,7 +102,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -99,7 +102,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
'parameters'
:
parameters
,
'parameters'
:
parameters
,
'parameter_source'
:
'algorithm'
'parameter_source'
:
'algorithm'
}
}
_logger
.
info
(
'New trial sent: %s'
,
new_trial
)
_logger
.
debug
(
'New trial sent: %s'
,
new_trial
)
send
(
CommandType
.
NewTrialJob
,
json_dumps
(
new_trial
))
send
(
CommandType
.
NewTrialJob
,
json_dumps
(
new_trial
))
if
self
.
send_trial_callback
is
not
None
:
if
self
.
send_trial_callback
is
not
None
:
self
.
send_trial_callback
(
parameters
)
# pylint: disable=not-callable
self
.
send_trial_callback
(
parameters
)
# pylint: disable=not-callable
...
@@ -109,21 +112,21 @@ class RetiariiAdvisor(MsgDispatcherBase):
...
@@ -109,21 +112,21 @@ class RetiariiAdvisor(MsgDispatcherBase):
send
(
CommandType
.
NoMoreTrialJobs
,
''
)
send
(
CommandType
.
NoMoreTrialJobs
,
''
)
def
handle_request_trial_jobs
(
self
,
num_trials
):
def
handle_request_trial_jobs
(
self
,
num_trials
):
_logger
.
info
(
'Request trial jobs: %s'
,
num_trials
)
_logger
.
debug
(
'Request trial jobs: %s'
,
num_trials
)
if
self
.
request_trial_jobs_callback
is
not
None
:
if
self
.
request_trial_jobs_callback
is
not
None
:
self
.
request_trial_jobs_callback
(
num_trials
)
# pylint: disable=not-callable
self
.
request_trial_jobs_callback
(
num_trials
)
# pylint: disable=not-callable
def
handle_update_search_space
(
self
,
data
):
def
handle_update_search_space
(
self
,
data
):
_logger
.
info
(
'Received search space: %s'
,
data
)
_logger
.
debug
(
'Received search space: %s'
,
data
)
self
.
search_space
=
data
self
.
search_space
=
data
def
handle_trial_end
(
self
,
data
):
def
handle_trial_end
(
self
,
data
):
_logger
.
info
(
'Trial end: %s'
,
data
)
_logger
.
debug
(
'Trial end: %s'
,
data
)
self
.
trial_end_callback
(
json_loads
(
data
[
'hyper_params'
])[
'parameter_id'
],
# pylint: disable=not-callable
self
.
trial_end_callback
(
json_loads
(
data
[
'hyper_params'
])[
'parameter_id'
],
# pylint: disable=not-callable
data
[
'event'
]
==
'SUCCEEDED'
)
data
[
'event'
]
==
'SUCCEEDED'
)
def
handle_report_metric_data
(
self
,
data
):
def
handle_report_metric_data
(
self
,
data
):
_logger
.
info
(
'Metric reported: %s'
,
data
)
_logger
.
debug
(
'Metric reported: %s'
,
data
)
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
raise
ValueError
(
'Request parameter not supported'
)
raise
ValueError
(
'Request parameter not supported'
)
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
...
...
nni/retiarii/integration_api.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
json
import
json
from
typing
import
NewType
,
Any
from
typing
import
NewType
,
Any
...
...
nni/retiarii/mutator.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
typing
import
(
Any
,
Iterable
,
List
,
Optional
)
from
typing
import
(
Any
,
Iterable
,
List
,
Optional
)
from
.graph
import
Model
from
.graph
import
Model
...
...
nni/retiarii/nn/pytorch/api.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
collections
import
OrderedDict
from
collections
import
OrderedDict
from
typing
import
Any
,
List
,
Union
,
Dict
from
typing
import
Any
,
List
,
Union
,
Dict
import
warnings
import
warnings
...
...
nni/retiarii/nn/pytorch/mutator.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
typing
import
Any
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
List
,
Optional
,
Tuple
from
...mutator
import
Mutator
from
...mutator
import
Mutator
...
...
nni/retiarii/nn/pytorch/nn.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
...
nni/retiarii/oneshot/__init__.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.interface
import
BaseOneShotTrainer
from
.interface
import
BaseOneShotTrainer
nni/retiarii/oneshot/interface.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
abc
import
abc
from
typing
import
Any
from
typing
import
Any
...
...
nni/retiarii/oneshot/pytorch/__init__.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.darts
import
DartsTrainer
from
.darts
import
DartsTrainer
from
.enas
import
EnasTrainer
from
.enas
import
EnasTrainer
from
.proxyless
import
ProxylessTrainer
from
.proxyless
import
ProxylessTrainer
...
...
nni/retiarii/operation.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
typing
import
(
Any
,
Dict
,
List
)
from
typing
import
(
Any
,
Dict
,
List
)
from
.
import
debug_configs
from
.
import
debug_configs
...
...
nni/retiarii/operation_def/tf_op_def.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
..operation
import
TensorFlowOperation
from
..operation
import
TensorFlowOperation
...
...
nni/retiarii/operation_def/torch_op_def.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
typing
import
(
Any
,
List
)
from
typing
import
(
Any
,
List
)
import
torch
import
torch
...
@@ -369,7 +372,6 @@ class TensorOps(PyTorchOperation):
...
@@ -369,7 +372,6 @@ class TensorOps(PyTorchOperation):
return
TensorOpExceptions
[
self
.
type
](
output
,
inputs
)
return
TensorOpExceptions
[
self
.
type
](
output
,
inputs
)
op_name
=
self
.
type
.
split
(
'::'
)[
-
1
]
op_name
=
self
.
type
.
split
(
'::'
)[
-
1
]
args_str
=
', '
.
join
([
f
'
{
name
}
=
{
inputs
[
i
+
1
]
}
'
for
i
,
(
name
,
t
,
default
)
in
enumerate
(
matched_args
)])
args_str
=
', '
.
join
([
f
'
{
name
}
=
{
inputs
[
i
+
1
]
}
'
for
i
,
(
name
,
t
,
default
)
in
enumerate
(
matched_args
)])
print
(
args_str
)
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.
{
op_name
}
(
{
args_str
}
)'
return
f
'
{
output
}
=
{
inputs
[
0
]
}
.
{
op_name
}
(
{
args_str
}
)'
class
TorchOps
(
PyTorchOperation
):
class
TorchOps
(
PyTorchOperation
):
...
...
nni/retiarii/serializer.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
abc
import
abc
import
functools
import
functools
import
inspect
import
inspect
...
...
nni/retiarii/strategy/__init__.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.base
import
BaseStrategy
from
.base
import
BaseStrategy
from
.bruteforce
import
Random
,
GridSearch
from
.bruteforce
import
Random
,
GridSearch
from
.evolution
import
RegularizedEvolution
from
.evolution
import
RegularizedEvolution
...
...
nni/retiarii/strategy/base.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
abc
import
abc
from
typing
import
List
from
typing
import
List
...
...
nni/retiarii/strategy/bruteforce.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
copy
import
copy
import
itertools
import
itertools
import
logging
import
logging
...
@@ -36,7 +39,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
...
@@ -36,7 +39,7 @@ def random_generator(search_space: Dict[Any, List[Any]], dedup=True, retries=500
history
.
add
(
selected
)
history
.
add
(
selected
)
break
break
if
retry_count
+
1
==
retries
:
if
retry_count
+
1
==
retries
:
_logger
.
info
(
'Random generation has run out of patience. There is nothing to search. Exiting.'
)
_logger
.
debug
(
'Random generation has run out of patience. There is nothing to search. Exiting.'
)
return
return
yield
{
key
:
value
for
key
,
value
in
zip
(
keys
,
selected
)}
yield
{
key
:
value
for
key
,
value
in
zip
(
keys
,
selected
)}
...
@@ -58,7 +61,7 @@ class GridSearch(BaseStrategy):
...
@@ -58,7 +61,7 @@ class GridSearch(BaseStrategy):
def
run
(
self
,
base_model
,
applied_mutators
):
def
run
(
self
,
base_model
,
applied_mutators
):
search_space
=
dry_run_for_search_space
(
base_model
,
applied_mutators
)
search_space
=
dry_run_for_search_space
(
base_model
,
applied_mutators
)
for
sample
in
grid_generator
(
search_space
,
shuffle
=
self
.
shuffle
):
for
sample
in
grid_generator
(
search_space
,
shuffle
=
self
.
shuffle
):
_logger
.
info
(
'New model created. Waiting for resource. %s'
,
str
(
sample
))
_logger
.
debug
(
'New model created. Waiting for resource. %s'
,
str
(
sample
))
if
query_available_resources
()
<=
0
:
if
query_available_resources
()
<=
0
:
time
.
sleep
(
self
.
_polling_interval
)
time
.
sleep
(
self
.
_polling_interval
)
submit_models
(
get_targeted_model
(
base_model
,
applied_mutators
,
sample
))
submit_models
(
get_targeted_model
(
base_model
,
applied_mutators
,
sample
))
...
@@ -101,7 +104,7 @@ class Random(BaseStrategy):
...
@@ -101,7 +104,7 @@ class Random(BaseStrategy):
model
=
base_model
model
=
base_model
for
mutator
in
applied_mutators
:
for
mutator
in
applied_mutators
:
model
=
mutator
.
apply
(
model
)
model
=
mutator
.
apply
(
model
)
_logger
.
info
(
'New model created. Applied mutators are: %s'
,
str
(
applied_mutators
))
_logger
.
debug
(
'New model created. Applied mutators are: %s'
,
str
(
applied_mutators
))
submit_models
(
model
)
submit_models
(
model
)
else
:
else
:
time
.
sleep
(
self
.
_polling_interval
)
time
.
sleep
(
self
.
_polling_interval
)
...
@@ -109,7 +112,7 @@ class Random(BaseStrategy):
...
@@ -109,7 +112,7 @@ class Random(BaseStrategy):
_logger
.
info
(
'Random search running in fixed size mode. Dedup: %s.'
,
'on'
if
self
.
dedup
else
'off'
)
_logger
.
info
(
'Random search running in fixed size mode. Dedup: %s.'
,
'on'
if
self
.
dedup
else
'off'
)
search_space
=
dry_run_for_search_space
(
base_model
,
applied_mutators
)
search_space
=
dry_run_for_search_space
(
base_model
,
applied_mutators
)
for
sample
in
random_generator
(
search_space
,
dedup
=
self
.
dedup
):
for
sample
in
random_generator
(
search_space
,
dedup
=
self
.
dedup
):
_logger
.
info
(
'New model created. Waiting for resource. %s'
,
str
(
sample
))
_logger
.
debug
(
'New model created. Waiting for resource. %s'
,
str
(
sample
))
if
query_available_resources
()
<=
0
:
if
query_available_resources
()
<=
0
:
time
.
sleep
(
self
.
_polling_interval
)
time
.
sleep
(
self
.
_polling_interval
)
submit_models
(
get_targeted_model
(
base_model
,
applied_mutators
,
sample
))
submit_models
(
get_targeted_model
(
base_model
,
applied_mutators
,
sample
))
nni/retiarii/strategy/evolution.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
collections
import
collections
import
dataclasses
import
dataclasses
import
logging
import
logging
...
@@ -122,7 +125,7 @@ class RegularizedEvolution(BaseStrategy):
...
@@ -122,7 +125,7 @@ class RegularizedEvolution(BaseStrategy):
break
break
def
_submit_config
(
self
,
config
,
base_model
,
mutators
):
def
_submit_config
(
self
,
config
,
base_model
,
mutators
):
_logger
.
info
(
'Model submitted to running queue: %s'
,
config
)
_logger
.
debug
(
'Model submitted to running queue: %s'
,
config
)
model
=
get_targeted_model
(
base_model
,
mutators
,
config
)
model
=
get_targeted_model
(
base_model
,
mutators
,
config
)
submit_models
(
model
)
submit_models
(
model
)
self
.
_running_models
.
append
((
config
,
model
))
self
.
_running_models
.
append
((
config
,
model
))
...
@@ -138,7 +141,7 @@ class RegularizedEvolution(BaseStrategy):
...
@@ -138,7 +141,7 @@ class RegularizedEvolution(BaseStrategy):
metric
=
model
.
metric
metric
=
model
.
metric
if
metric
is
not
None
:
if
metric
is
not
None
:
individual
=
Individual
(
config
,
metric
)
individual
=
Individual
(
config
,
metric
)
_logger
.
info
(
'Individual created: %s'
,
str
(
individual
))
_logger
.
debug
(
'Individual created: %s'
,
str
(
individual
))
self
.
_population
.
append
(
individual
)
self
.
_population
.
append
(
individual
)
if
len
(
self
.
_population
)
>
self
.
population_size
:
if
len
(
self
.
_population
)
>
self
.
population_size
:
self
.
_population
.
popleft
()
self
.
_population
.
popleft
()
...
...
nni/retiarii/strategy/tpe_strategy.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
logging
import
logging
import
time
import
time
...
@@ -55,7 +58,7 @@ class TPEStrategy(BaseStrategy):
...
@@ -55,7 +58,7 @@ class TPEStrategy(BaseStrategy):
avail_resource
=
query_available_resources
()
avail_resource
=
query_available_resources
()
if
avail_resource
>
0
:
if
avail_resource
>
0
:
model
=
base_model
model
=
base_model
_logger
.
info
(
'New model created. Applied mutators: %s'
,
str
(
applied_mutators
))
_logger
.
debug
(
'New model created. Applied mutators: %s'
,
str
(
applied_mutators
))
self
.
tpe_sampler
.
generate_samples
(
self
.
model_id
)
self
.
tpe_sampler
.
generate_samples
(
self
.
model_id
)
for
mutator
in
applied_mutators
:
for
mutator
in
applied_mutators
:
mutator
.
bind_sampler
(
self
.
tpe_sampler
)
mutator
.
bind_sampler
(
self
.
tpe_sampler
)
...
...
nni/retiarii/strategy/utils.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
collections
import
collections
from
typing
import
Dict
,
Any
,
List
from
typing
import
Dict
,
Any
,
List
from
..graph
import
Model
from
..graph
import
Model
...
...
nni/retiarii/trial_entry.py
View file @
e457047c
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
"""
Entrypoint for trials.
Entrypoint for trials.
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment