Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bc6d8796
Unverified
Commit
bc6d8796
authored
Aug 01, 2022
by
Yuge Zhang
Committed by
GitHub
Aug 01, 2022
Browse files
Promote Retiarii to NAS (step 2) - update imports (#5025)
parent
867871b2
Changes
173
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
113 additions
and
84 deletions
+113
-84
nni/nas/execution/pytorch/benchmark.py
nni/nas/execution/pytorch/benchmark.py
+2
-4
nni/nas/execution/pytorch/cgo/__init__.py
nni/nas/execution/pytorch/cgo/__init__.py
+4
-0
nni/nas/execution/pytorch/cgo/engine.py
nni/nas/execution/pytorch/cgo/engine.py
+13
-9
nni/nas/execution/pytorch/cgo/logical_optimizer/__init__.py
nni/nas/execution/pytorch/cgo/logical_optimizer/__init__.py
+0
-0
nni/nas/execution/pytorch/cgo/logical_optimizer/logical_plan.py
...s/execution/pytorch/cgo/logical_optimizer/logical_plan.py
+2
-2
nni/nas/execution/pytorch/cgo/logical_optimizer/opt_dedup_input.py
...xecution/pytorch/cgo/logical_optimizer/opt_dedup_input.py
+3
-3
nni/nas/execution/pytorch/codegen.py
nni/nas/execution/pytorch/codegen.py
+5
-5
nni/nas/execution/pytorch/converter/__init__.py
nni/nas/execution/pytorch/converter/__init__.py
+4
-0
nni/nas/execution/pytorch/converter/graph_gen.py
nni/nas/execution/pytorch/converter/graph_gen.py
+3
-5
nni/nas/execution/pytorch/converter/utils.py
nni/nas/execution/pytorch/converter/utils.py
+1
-2
nni/nas/execution/pytorch/graph.py
nni/nas/execution/pytorch/graph.py
+11
-8
nni/nas/execution/pytorch/op_def.py
nni/nas/execution/pytorch/op_def.py
+1
-1
nni/nas/execution/pytorch/simplified.py
nni/nas/execution/pytorch/simplified.py
+7
-5
nni/nas/execution/tensorflow/__init__.py
nni/nas/execution/tensorflow/__init__.py
+0
-0
nni/nas/execution/tensorflow/op_def.py
nni/nas/execution/tensorflow/op_def.py
+1
-1
nni/nas/execution/trial_entry.py
nni/nas/execution/trial_entry.py
+10
-7
nni/nas/experiment/__init__.py
nni/nas/experiment/__init__.py
+8
-0
nni/nas/experiment/config/__init__.py
nni/nas/experiment/config/__init__.py
+5
-0
nni/nas/experiment/pytorch.py
nni/nas/experiment/pytorch.py
+31
-32
nni/nas/experiment/tensorflow.py
nni/nas/experiment/tensorflow.py
+2
-0
No files found.
nni/nas/execution/pytorch/benchmark.py
View file @
bc6d8796
...
@@ -5,10 +5,8 @@ import os
...
@@ -5,10 +5,8 @@ import os
import
random
import
random
from
typing
import
Dict
,
Any
,
List
,
Optional
,
Union
,
Tuple
,
Callable
,
Iterable
,
cast
from
typing
import
Dict
,
Any
,
List
,
Optional
,
Union
,
Tuple
,
Callable
,
Iterable
,
cast
from
..graph
import
Model
from
nni.nas.execution.common
import
Model
,
receive_trial_parameters
,
get_mutation_dict
from
..integration_api
import
receive_trial_parameters
from
.graph
import
BaseExecutionEngine
from
.base
import
BaseExecutionEngine
from
.utils
import
get_mutation_dict
class
BenchmarkGraphData
:
class
BenchmarkGraphData
:
...
...
nni/nas/execution/pytorch/cgo/__init__.py
0 → 100644
View file @
bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.engine
import
*
nni/nas/execution/pytorch/cgo.py
→
nni/nas/execution/pytorch/cgo
/engine
.py
View file @
bc6d8796
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
from
__future__
import
annotations
from
__future__
import
annotations
__all__
=
[
'CGOExecutionEngine'
,
'TrialSubmission'
]
import
logging
import
logging
import
os
import
os
import
random
import
random
...
@@ -14,17 +16,19 @@ from dataclasses import dataclass
...
@@ -14,17 +16,19 @@ from dataclasses import dataclass
from
nni.common.device
import
GPUDevice
,
Device
from
nni.common.device
import
GPUDevice
,
Device
from
nni.experiment.config.training_services
import
RemoteConfig
from
nni.experiment.config.training_services
import
RemoteConfig
from
nni.retiarii.integration
import
RetiariiAdvisor
from
nni.nas
import
utils
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
,
WorkerInfo
from
nni.nas.execution.common
import
(
from
..
import
codegen
,
utils
AbstractExecutionEngine
,
AbstractGraphListener
,
WorkerInfo
,
from
..graph
import
Model
,
ModelStatus
,
MetricData
,
Node
Model
,
ModelStatus
,
MetricData
,
Node
,
from
..integration_api
import
send_trial
,
receive_trial_parameters
,
get_advisor
RetiariiAdvisor
,
send_trial
,
receive_trial_parameters
,
get_advisor
,
)
from
nni.nas.execution.pytorch
import
codegen
from
nni.nas.evaluator.pytorch.lightning
import
Lightning
from
nni.nas.evaluator.pytorch.cgo.evaluator
import
_MultiModelSupervisedLearningModule
from
nni.nas.execution.pytorch.graph
import
BaseGraphData
from
.logical_optimizer.logical_plan
import
LogicalPlan
,
AbstractLogicalNode
from
.logical_optimizer.logical_plan
import
LogicalPlan
,
AbstractLogicalNode
from
.logical_optimizer.opt_dedup_input
import
DedupInputOptimizer
from
.logical_optimizer.opt_dedup_input
import
DedupInputOptimizer
from
..evaluator.pytorch.lightning
import
Lightning
from
..evaluator.pytorch.cgo.evaluator
import
_MultiModelSupervisedLearningModule
from
.base
import
BaseGraphData
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
...
nni/nas/execution/pytorch/cgo/logical_optimizer/__init__.py
0 → 100644
View file @
bc6d8796
nni/nas/execution/pytorch/cgo/logical_optimizer/logical_plan.py
View file @
bc6d8796
...
@@ -7,8 +7,8 @@ from typing import Dict, Tuple, Any
...
@@ -7,8 +7,8 @@ from typing import Dict, Tuple, Any
from
nni.retiarii.utils
import
uid
from
nni.retiarii.utils
import
uid
from
nni.common.device
import
Device
,
CPUDevice
from
nni.common.device
import
Device
,
CPUDevice
from
..
.graph
import
Cell
,
Edge
,
Graph
,
Model
,
Node
from
nni.nas.execution.common
.graph
import
Cell
,
Edge
,
Graph
,
Model
,
Node
from
...operation
import
Operation
,
_IOPseudoOperation
from
nni.nas.execution.common.graph_op
import
Operation
,
_IOPseudoOperation
class
AbstractLogicalNode
(
Node
):
class
AbstractLogicalNode
(
Node
):
...
...
nni/nas/execution/pytorch/cgo/logical_optimizer/opt_dedup_input.py
View file @
bc6d8796
...
@@ -3,11 +3,11 @@
...
@@ -3,11 +3,11 @@
from
typing
import
List
,
Dict
,
Tuple
from
typing
import
List
,
Dict
,
Tuple
from
nni.
retiarii
.utils
import
uid
from
nni.
nas
.utils
import
uid
from
nni.
retiarii
.evaluator.pytorch.cgo.evaluator
import
MultiModelSupervisedLearningModule
from
nni.
nas
.evaluator.pytorch.cgo.evaluator
import
MultiModelSupervisedLearningModule
from
nni.common.device
import
GPUDevice
from
nni.common.device
import
GPUDevice
from
..
.graph
import
Graph
,
Model
,
Node
from
nni.nas.execution.common
.graph
import
Graph
,
Model
,
Node
from
.interface
import
AbstractOptimizer
from
.interface
import
AbstractOptimizer
from
.logical_plan
import
(
AbstractLogicalNode
,
LogicalGraph
,
LogicalPlan
,
from
.logical_plan
import
(
AbstractLogicalNode
,
LogicalGraph
,
LogicalPlan
,
OriginNode
)
OriginNode
)
...
...
nni/nas/execution/pytorch/codegen.py
View file @
bc6d8796
...
@@ -7,12 +7,12 @@ import logging
...
@@ -7,12 +7,12 @@ import logging
import
re
import
re
from
typing
import
Dict
,
List
,
Tuple
,
Any
,
cast
from
typing
import
Dict
,
List
,
Tuple
,
Any
,
cast
from
nni.retiarii.operation
import
PyTorchOperation
from
nni.retiarii.operation_def.torch_op_def
import
ToDevice
from
nni.retiarii.utils
import
STATE_DICT_PY_MAPPING
from
nni.common.device
import
Device
,
GPUDevice
from
nni.common.device
import
Device
,
GPUDevice
from
nni.nas.execution.common.graph
import
IllegalGraphError
,
Edge
,
Graph
,
Node
,
Model
from
nni.nas.execution.common.graph_op
import
PyTorchOperation
from
nni.nas.utils
import
STATE_DICT_PY_MAPPING
from
.
.graph
import
IllegalGraphError
,
Edge
,
Graph
,
Node
,
Model
from
.
op_def
import
ToDevice
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -215,7 +215,7 @@ import torch.nn as nn
...
@@ -215,7 +215,7 @@ import torch.nn as nn
import torch.nn.functional as F
import torch.nn.functional as F
import torch.optim as optim
import torch.optim as optim
import nni.
retiarii
.nn.pytorch
import nni.
nas
.nn.pytorch
{}
{}
...
...
nni/nas/execution/pytorch/converter/__init__.py
0 → 100644
View file @
bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.graph_gen
import
convert_to_graph
nni/nas/execution/pytorch/converter/graph_gen.py
View file @
bc6d8796
...
@@ -5,11 +5,9 @@ import re
...
@@ -5,11 +5,9 @@ import re
import
torch
import
torch
from
..graph
import
Graph
,
Model
,
Node
from
nni.nas.execution.common
import
Graph
,
Model
,
Node
,
Cell
,
Operation
from
..nn.pytorch
import
InputChoice
,
Placeholder
,
LayerChoice
from
nni.nas.nn.pytorch
import
InputChoice
,
Placeholder
,
LayerChoice
from
..operation
import
Cell
,
Operation
from
nni.nas.utils
import
get_init_parameters_or_fail
,
get_importable_name
from
..serializer
import
get_init_parameters_or_fail
from
..utils
import
get_importable_name
from
.op_types
import
MODULE_EXCEPT_LIST
,
OpTypeName
from
.op_types
import
MODULE_EXCEPT_LIST
,
OpTypeName
from
.utils
import
(
from
.utils
import
(
_convert_name
,
build_full_name
,
_without_shape_info
,
_convert_name
,
build_full_name
,
_without_shape_info
,
...
...
nni/nas/execution/pytorch/converter/utils.py
View file @
bc6d8796
...
@@ -5,8 +5,7 @@ from typing import Optional
...
@@ -5,8 +5,7 @@ from typing import Optional
from
typing_extensions
import
TypeGuard
from
typing_extensions
import
TypeGuard
from
..operation
import
Cell
from
nni.nas.execution.common
import
Cell
,
Model
,
Graph
,
Node
,
Edge
from
..graph
import
Model
,
Graph
,
Node
,
Edge
def
build_full_name
(
prefix
,
name
,
seq
=
None
):
def
build_full_name
(
prefix
,
name
,
seq
=
None
):
...
...
nni/nas/execution/pytorch/graph.py
View file @
bc6d8796
...
@@ -3,6 +3,8 @@
...
@@ -3,6 +3,8 @@
from
__future__
import
annotations
from
__future__
import
annotations
__all__
=
[
'BaseGraphData'
,
'BaseExecutionEngine'
]
import
logging
import
logging
import
os
import
os
import
random
import
random
...
@@ -10,13 +12,14 @@ import string
...
@@ -10,13 +12,14 @@ import string
from
typing
import
Any
,
Dict
,
Iterable
,
List
from
typing
import
Any
,
Dict
,
Iterable
,
List
from
nni.experiment
import
rest
from
nni.experiment
import
rest
from
nni.retiarii.integration
import
RetiariiAdvisor
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
from
nni.nas.execution.common
import
(
from
.utils
import
get_mutation_summary
AbstractExecutionEngine
,
AbstractGraphListener
,
RetiariiAdvisor
,
get_mutation_summary
,
from
..
import
codegen
,
utils
Model
,
ModelStatus
,
MetricData
,
Evaluator
,
from
..graph
import
Model
,
ModelStatus
,
MetricData
,
Evaluator
send_trial
,
receive_trial_parameters
,
get_advisor
from
..integration_api
import
send_trial
,
receive_trial_parameters
,
get_advisor
)
from
nni.nas.utils
import
import_
from
.codegen
import
model_to_pytorch_script
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
...
@@ -146,7 +149,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
...
@@ -146,7 +149,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
def
pack_model_data
(
cls
,
model
:
Model
)
->
Any
:
mutation_summary
=
get_mutation_summary
(
model
)
mutation_summary
=
get_mutation_summary
(
model
)
assert
model
.
evaluator
is
not
None
,
'Model evaluator can not be None'
assert
model
.
evaluator
is
not
None
,
'Model evaluator can not be None'
return
BaseGraphData
(
codegen
.
pytorch
.
model_to_pytorch_script
(
model
),
model
.
evaluator
,
mutation_summary
)
# type: ignore
return
BaseGraphData
(
model_to_pytorch_script
(
model
),
model
.
evaluator
,
mutation_summary
)
# type: ignore
@
classmethod
@
classmethod
def
trial_execute_graph
(
cls
)
->
None
:
def
trial_execute_graph
(
cls
)
->
None
:
...
@@ -159,6 +162,6 @@ class BaseExecutionEngine(AbstractExecutionEngine):
...
@@ -159,6 +162,6 @@ class BaseExecutionEngine(AbstractExecutionEngine):
os
.
makedirs
(
os
.
path
.
dirname
(
file_name
),
exist_ok
=
True
)
os
.
makedirs
(
os
.
path
.
dirname
(
file_name
),
exist_ok
=
True
)
with
open
(
file_name
,
'w'
)
as
f
:
with
open
(
file_name
,
'w'
)
as
f
:
f
.
write
(
graph_data
.
model_script
)
f
.
write
(
graph_data
.
model_script
)
model_cls
=
utils
.
import_
(
f
'_generated_model.
{
random_str
}
._model'
)
model_cls
=
import_
(
f
'_generated_model.
{
random_str
}
._model'
)
graph_data
.
evaluator
.
_execute
(
model_cls
)
graph_data
.
evaluator
.
_execute
(
model_cls
)
os
.
remove
(
file_name
)
os
.
remove
(
file_name
)
nni/nas/execution/pytorch/op_def.py
View file @
bc6d8796
...
@@ -8,7 +8,7 @@ from typing import (Any, Dict, List)
...
@@ -8,7 +8,7 @@ from typing import (Any, Dict, List)
import
torch
import
torch
import
torch.nn.functional
as
nn_functional
import
torch.nn.functional
as
nn_functional
from
..operati
on
import
PyTorchOperation
from
nni.nas.execution.comm
on
import
PyTorchOperation
mem_format
=
[
mem_format
=
[
...
...
nni/nas/execution/pytorch/simplified.py
View file @
bc6d8796
...
@@ -5,11 +5,13 @@ from typing import Dict, Any, Type, cast
...
@@ -5,11 +5,13 @@ from typing import Dict, Any, Type, cast
import
torch.nn
as
nn
import
torch.nn
as
nn
from
..graph
import
Evaluator
,
Model
from
nni.nas.execution.common
import
(
from
..integration_api
import
receive_trial_parameters
Model
,
receive_trial_parameters
,
from
..utils
import
ContextStack
get_mutation_dict
,
mutation_dict_to_summary
from
.base
import
BaseExecutionEngine
)
from
.utils
import
get_mutation_dict
,
mutation_dict_to_summary
from
nni.nas.evaluator
import
Evaluator
from
nni.nas.utils
import
ContextStack
from
.graph
import
BaseExecutionEngine
class
PythonGraphData
:
class
PythonGraphData
:
...
...
nni/nas/execution/tensorflow/__init__.py
0 → 100644
View file @
bc6d8796
nni/nas/execution/tensorflow/op_def.py
View file @
bc6d8796
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
..operati
on
import
TensorFlowOperation
from
nni.nas.execution.comm
on
import
TensorFlowOperation
class
Conv2D
(
TensorFlowOperation
):
class
Conv2D
(
TensorFlowOperation
):
...
...
nni/nas/execution/trial_entry.py
View file @
bc6d8796
...
@@ -3,28 +3,31 @@
...
@@ -3,28 +3,31 @@
"""
"""
Entrypoint for trials.
Entrypoint for trials.
Assuming execution engine is BaseExecutionEngine.
"""
"""
import
argparse
import
argparse
if
__name__
==
'__
main
__'
:
def
main
()
:
parser
=
argparse
.
ArgumentParser
()
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
'exec'
,
choices
=
[
'base'
,
'py'
,
'cgo'
,
'benchmark'
])
parser
.
add_argument
(
'exec'
,
choices
=
[
'base'
,
'py'
,
'cgo'
,
'benchmark'
])
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
if
args
.
exec
==
'base'
:
if
args
.
exec
==
'base'
:
from
.
execution.base
import
BaseExecutionEngine
from
.
pytorch.graph
import
BaseExecutionEngine
engine
=
BaseExecutionEngine
engine
=
BaseExecutionEngine
elif
args
.
exec
==
'cgo'
:
elif
args
.
exec
==
'cgo'
:
from
.
execution.cgo_engine
import
CGOExecutionEngine
from
.
pytorch.cgo
import
CGOExecutionEngine
engine
=
CGOExecutionEngine
engine
=
CGOExecutionEngine
elif
args
.
exec
==
'py'
:
elif
args
.
exec
==
'py'
:
from
.
execution.python
import
PurePythonExecutionEngine
from
.
pytorch.simplified
import
PurePythonExecutionEngine
engine
=
PurePythonExecutionEngine
engine
=
PurePythonExecutionEngine
elif
args
.
exec
==
'benchmark'
:
elif
args
.
exec
==
'benchmark'
:
from
.
execution
.benchmark
import
BenchmarkExecutionEngine
from
.
pytorch
.benchmark
import
BenchmarkExecutionEngine
engine
=
BenchmarkExecutionEngine
engine
=
BenchmarkExecutionEngine
else
:
else
:
raise
ValueError
(
f
'Unrecognized benchmark name:
{
args
.
exec
}
'
)
raise
ValueError
(
f
'Unrecognized benchmark name:
{
args
.
exec
}
'
)
engine
.
trial_execute_graph
()
engine
.
trial_execute_graph
()
if
__name__
==
'__main__'
:
main
()
nni/nas/experiment/__init__.py
0 → 100644
View file @
bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
nni.common.framework
import
shortcut_framework
shortcut_framework
(
__name__
)
del
shortcut_framework
nni/nas/experiment/config/__init__.py
0 → 100644
View file @
bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from
.experiment_config
import
*
from
.engine_config
import
*
\ No newline at end of file
nni/nas/experiment/pytorch.py
View file @
bc6d8796
...
@@ -3,11 +3,13 @@
...
@@ -3,11 +3,13 @@
from
__future__
import
annotations
from
__future__
import
annotations
__all__
=
[
'RetiariiExeConfig'
,
'RetiariiExperiment'
,
'preprocess_model'
,
'debug_mutated_model'
]
import
logging
import
logging
import
warnings
import
warnings
from
threading
import
Thread
from
threading
import
Thread
from
typing
import
Any
,
List
,
Union
,
cast
from
typing
import
Any
,
List
,
cast
import
colorama
import
colorama
...
@@ -16,32 +18,27 @@ import torch.nn as nn
...
@@ -16,32 +18,27 @@ import torch.nn as nn
from
nni.experiment
import
Experiment
,
RunMode
from
nni.experiment
import
Experiment
,
RunMode
from
nni.experiment.config.training_services
import
RemoteConfig
from
nni.experiment.config.training_services
import
RemoteConfig
from
nni.nas.execution
import
list_models
,
set_execution_engine
from
nni.nas.execution.common
import
RetiariiAdvisor
,
get_mutation_dict
from
nni.nas.execution.pytorch.codegen
import
model_to_pytorch_script
from
nni.nas.execution.pytorch.converter
import
convert_to_graph
from
nni.nas.execution.pytorch.converter.graph_gen
import
GraphConverterWithShape
from
nni.nas.evaluator
import
Evaluator
from
nni.nas.mutable
import
Mutator
from
nni.nas.nn.pytorch.mutator
import
(
extract_mutation_from_pt_module
,
process_inline_mutation
,
process_evaluator_mutations
,
process_oneshot_mutations
)
from
nni.nas.utils
import
is_model_wrapped
from
nni.nas.strategy
import
BaseStrategy
from
nni.nas.strategy.utils
import
dry_run_for_formatted_search_space
from
.config
import
(
from
.config
import
(
RetiariiExeConfig
,
OneshotEngineConfig
,
BaseEngineConfig
,
RetiariiExeConfig
,
OneshotEngineConfig
,
BaseEngineConfig
,
PyEngineConfig
,
CgoEngineConfig
,
BenchmarkEngineConfig
PyEngineConfig
,
CgoEngineConfig
,
BenchmarkEngineConfig
)
)
from
..codegen.pytorch
import
model_to_pytorch_script
from
..converter
import
convert_to_graph
from
..converter.graph_gen
import
GraphConverterWithShape
from
..execution
import
list_models
,
set_execution_engine
from
..execution.utils
import
get_mutation_dict
from
..graph
import
Evaluator
from
..integration
import
RetiariiAdvisor
from
..mutator
import
Mutator
from
..nn.pytorch.mutator
import
(
extract_mutation_from_pt_module
,
process_inline_mutation
,
process_evaluator_mutations
,
process_oneshot_mutations
)
from
..oneshot.interface
import
BaseOneShotTrainer
from
..serializer
import
is_model_wrapped
from
..strategy
import
BaseStrategy
from
..strategy.utils
import
dry_run_for_formatted_search_space
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
'RetiariiExperiment'
]
def
preprocess_model
(
base_model
,
evaluator
,
applied_mutators
,
full_ir
=
True
,
dummy_input
=
None
,
oneshot
=
False
):
def
preprocess_model
(
base_model
,
evaluator
,
applied_mutators
,
full_ir
=
True
,
dummy_input
=
None
,
oneshot
=
False
):
# TODO: this logic might need to be refactored into execution engine
# TODO: this logic might need to be refactored into execution engine
if
oneshot
:
if
oneshot
:
...
@@ -97,7 +94,7 @@ def debug_mutated_model(base_model, evaluator, applied_mutators):
...
@@ -97,7 +94,7 @@ def debug_mutated_model(base_model, evaluator, applied_mutators):
a list of mutators that will be applied on the base model for generating a new model
a list of mutators that will be applied on the base model for generating a new model
"""
"""
base_model_ir
,
applied_mutators
=
preprocess_model
(
base_model
,
evaluator
,
applied_mutators
)
base_model_ir
,
applied_mutators
=
preprocess_model
(
base_model
,
evaluator
,
applied_mutators
)
from
.
.strategy.
local_debug_strategy
import
_LocalDebugStrategy
from
nni.nas
.strategy.
debug
import
_LocalDebugStrategy
strategy
=
_LocalDebugStrategy
()
strategy
=
_LocalDebugStrategy
()
strategy
.
run
(
base_model_ir
,
applied_mutators
)
strategy
.
run
(
base_model_ir
,
applied_mutators
)
_logger
.
info
(
'local debug completed!'
)
_logger
.
info
(
'local debug completed!'
)
...
@@ -174,10 +171,10 @@ class RetiariiExperiment(Experiment):
...
@@ -174,10 +171,10 @@ class RetiariiExperiment(Experiment):
"""
"""
def
__init__
(
self
,
base_model
:
nn
.
Module
,
def
__init__
(
self
,
base_model
:
nn
.
Module
,
evaluator
:
Union
[
BaseOneShotTrainer
,
Evaluator
]
=
cast
(
Evaluator
,
None
),
evaluator
:
Evaluator
=
cast
(
Evaluator
,
None
),
applied_mutators
:
List
[
Mutator
]
=
cast
(
List
[
Mutator
],
None
),
applied_mutators
:
List
[
Mutator
]
=
cast
(
List
[
Mutator
],
None
),
strategy
:
BaseStrategy
=
cast
(
BaseStrategy
,
None
),
strategy
:
BaseStrategy
=
cast
(
BaseStrategy
,
None
),
trainer
:
BaseOneShotTrainer
=
cast
(
BaseOneShotTrainer
,
None
)
)
:
trainer
:
Any
=
None
):
super
().
__init__
(
None
)
super
().
__init__
(
None
)
self
.
config
:
RetiariiExeConfig
=
cast
(
RetiariiExeConfig
,
None
)
self
.
config
:
RetiariiExeConfig
=
cast
(
RetiariiExeConfig
,
None
)
...
@@ -190,7 +187,7 @@ class RetiariiExperiment(Experiment):
...
@@ -190,7 +187,7 @@ class RetiariiExperiment(Experiment):
raise
ValueError
(
'Evaluator should not be none.'
)
raise
ValueError
(
'Evaluator should not be none.'
)
self
.
base_model
=
base_model
self
.
base_model
=
base_model
self
.
evaluator
:
Union
[
Evaluator
,
BaseOneShotTrainer
]
=
evaluator
self
.
evaluator
:
Evaluator
=
evaluator
self
.
applied_mutators
=
applied_mutators
self
.
applied_mutators
=
applied_mutators
self
.
strategy
=
strategy
self
.
strategy
=
strategy
...
@@ -222,10 +219,10 @@ class RetiariiExperiment(Experiment):
...
@@ -222,10 +219,10 @@ class RetiariiExperiment(Experiment):
def
_create_execution_engine
(
self
,
config
:
RetiariiExeConfig
)
->
None
:
def
_create_execution_engine
(
self
,
config
:
RetiariiExeConfig
)
->
None
:
#TODO: we will probably need a execution engine factory to make this clean and elegant
#TODO: we will probably need a execution engine factory to make this clean and elegant
if
isinstance
(
config
.
execution_engine
,
BaseEngineConfig
):
if
isinstance
(
config
.
execution_engine
,
BaseEngineConfig
):
from
.
.execution.
base
import
BaseExecutionEngine
from
nni.nas
.execution.
pytorch.graph
import
BaseExecutionEngine
engine
=
BaseExecutionEngine
(
self
.
port
,
self
.
url_prefix
)
engine
=
BaseExecutionEngine
(
self
.
port
,
self
.
url_prefix
)
elif
isinstance
(
config
.
execution_engine
,
CgoEngineConfig
):
elif
isinstance
(
config
.
execution_engine
,
CgoEngineConfig
):
from
.
.execution.
cgo_engine
import
CGOExecutionEngine
from
nni.nas
.execution.
pytorch.cgo
import
CGOExecutionEngine
assert
not
isinstance
(
config
.
training_service
,
list
)
\
assert
not
isinstance
(
config
.
training_service
,
list
)
\
and
config
.
training_service
.
platform
==
'remote'
,
\
and
config
.
training_service
.
platform
==
'remote'
,
\
...
@@ -238,10 +235,10 @@ class RetiariiExperiment(Experiment):
...
@@ -238,10 +235,10 @@ class RetiariiExperiment(Experiment):
rest_port
=
self
.
port
,
rest_port
=
self
.
port
,
rest_url_prefix
=
self
.
url_prefix
)
rest_url_prefix
=
self
.
url_prefix
)
elif
isinstance
(
config
.
execution_engine
,
PyEngineConfig
):
elif
isinstance
(
config
.
execution_engine
,
PyEngineConfig
):
from
.
.execution.pyt
hon
import
PurePythonExecutionEngine
from
nni.nas
.execution.pyt
orch.simplified
import
PurePythonExecutionEngine
engine
=
PurePythonExecutionEngine
(
self
.
port
,
self
.
url_prefix
)
engine
=
PurePythonExecutionEngine
(
self
.
port
,
self
.
url_prefix
)
elif
isinstance
(
config
.
execution_engine
,
BenchmarkEngineConfig
):
elif
isinstance
(
config
.
execution_engine
,
BenchmarkEngineConfig
):
from
.
.execution.benchmark
import
BenchmarkExecutionEngine
from
nni.nas
.execution.
pytorch.
benchmark
import
BenchmarkExecutionEngine
assert
config
.
execution_engine
.
benchmark
is
not
None
,
\
assert
config
.
execution_engine
.
benchmark
is
not
None
,
\
'"benchmark" must be set when benchmark execution engine is used.'
'"benchmark" must be set when benchmark execution engine is used.'
engine
=
BenchmarkExecutionEngine
(
config
.
execution_engine
.
benchmark
)
engine
=
BenchmarkExecutionEngine
(
config
.
execution_engine
.
benchmark
)
...
@@ -265,12 +262,13 @@ class RetiariiExperiment(Experiment):
...
@@ -265,12 +262,13 @@ class RetiariiExperiment(Experiment):
Run the experiment.
Run the experiment.
This function will block until experiment finish or error.
This function will block until experiment finish or error.
"""
"""
from
nni.retiarii.oneshot.interface
import
BaseOneShotTrainer
if
isinstance
(
self
.
evaluator
,
BaseOneShotTrainer
):
if
isinstance
(
self
.
evaluator
,
BaseOneShotTrainer
):
# TODO: will throw a deprecation warning soon
warnings
.
warn
(
'You are using the old implementation of one-shot algos based on One-shot trainer. '
# warnings.warn('You are using the old implementation of one-shot algos based on One-shot trainer. '
'We will try to convert this trainer to our new implementation to run the algorithm. '
# 'We will try to convert this trainer to our new implementation to run the algorithm. '
'In case you want to stick to the old implementation, '
# 'In case you want to stick to the old implementation, '
'please consider using ``trainer.fit()`` instead of experiment.'
,
DeprecationWarning
)
# 'please consider using ``trainer.fit()`` instead of experiment.', DeprecationWarning)
self
.
evaluator
.
fit
()
self
.
evaluator
.
fit
()
return
return
...
@@ -344,6 +342,7 @@ class RetiariiExperiment(Experiment):
...
@@ -344,6 +342,7 @@ class RetiariiExperiment(Experiment):
config
=
self
.
config
.
canonical_copy
()
config
=
self
.
config
.
canonical_copy
()
assert
not
isinstance
(
config
.
execution_engine
,
PyEngineConfig
),
\
assert
not
isinstance
(
config
.
execution_engine
,
PyEngineConfig
),
\
'You should use `dict` formatter when using Python execution engine.'
'You should use `dict` formatter when using Python execution engine.'
from
nni.retiarii.oneshot.interface
import
BaseOneShotTrainer
if
isinstance
(
self
.
evaluator
,
BaseOneShotTrainer
):
if
isinstance
(
self
.
evaluator
,
BaseOneShotTrainer
):
assert
top_k
==
1
,
'Only support top_k is 1 for now.'
assert
top_k
==
1
,
'Only support top_k is 1 for now.'
return
self
.
evaluator
.
export
()
return
self
.
evaluator
.
export
()
...
...
nni/nas/experiment/tensorflow.py
0 → 100644
View file @
bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
Prev
1
2
3
4
5
6
…
9
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment