Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
59cd3982
Unverified
Commit
59cd3982
authored
Dec 14, 2020
by
Yuge Zhang
Committed by
GitHub
Dec 14, 2020
Browse files
[Retiarii] Coding style improvements for pylint and flake8 (#3190)
parent
593a275c
Changes
34
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
148 additions
and
152 deletions
+148
-152
nni/retiarii/codegen/pytorch.py
nni/retiarii/codegen/pytorch.py
+10
-9
nni/retiarii/converter/__init__.py
nni/retiarii/converter/__init__.py
+0
-1
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+29
-18
nni/retiarii/converter/op_types.py
nni/retiarii/converter/op_types.py
+3
-2
nni/retiarii/converter/utils.py
nni/retiarii/converter/utils.py
+1
-0
nni/retiarii/converter/visualize.py
nni/retiarii/converter/visualize.py
+3
-1
nni/retiarii/execution/api.py
nni/retiarii/execution/api.py
+2
-3
nni/retiarii/execution/base.py
nni/retiarii/execution/base.py
+4
-4
nni/retiarii/execution/cgo_engine.py
nni/retiarii/execution/cgo_engine.py
+16
-17
nni/retiarii/execution/interface.py
nni/retiarii/execution/interface.py
+2
-2
nni/retiarii/execution/listener.py
nni/retiarii/execution/listener.py
+2
-4
nni/retiarii/execution/logical_optimizer/interface.py
nni/retiarii/execution/logical_optimizer/interface.py
+3
-3
nni/retiarii/execution/logical_optimizer/logical_plan.py
nni/retiarii/execution/logical_optimizer/logical_plan.py
+12
-15
nni/retiarii/execution/logical_optimizer/opt_batching.py
nni/retiarii/execution/logical_optimizer/opt_batching.py
+0
-10
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
+26
-25
nni/retiarii/execution/logical_optimizer/opt_weight_sharing.py
...etiarii/execution/logical_optimizer/opt_weight_sharing.py
+0
-10
nni/retiarii/experiment.py
nni/retiarii/experiment.py
+16
-11
nni/retiarii/graph.py
nni/retiarii/graph.py
+6
-5
nni/retiarii/integration.py
nni/retiarii/integration.py
+8
-10
nni/retiarii/mutator.py
nni/retiarii/mutator.py
+5
-2
No files found.
nni/retiarii/codegen/pytorch.py
View file @
59cd3982
import
logging
from
typing
import
*
from
typing
import
List
from
..graph
import
IllegalGraphError
,
Edge
,
Graph
,
Node
,
Model
from
..operation
import
Operation
,
Cell
_logger
=
logging
.
getLogger
(
__name__
)
def
model_to_pytorch_script
(
model
:
Model
,
placement
=
None
)
->
str
:
def
model_to_pytorch_script
(
model
:
Model
,
placement
=
None
)
->
str
:
graphs
=
[]
total_pkgs
=
set
()
for
name
,
cell
in
model
.
graphs
.
items
():
import_pkgs
,
graph_code
=
graph_to_pytorch_model
(
name
,
cell
,
placement
=
placement
)
import_pkgs
,
graph_code
=
graph_to_pytorch_model
(
name
,
cell
,
placement
=
placement
)
graphs
.
append
(
graph_code
)
total_pkgs
.
update
(
import_pkgs
)
pkgs_code
=
'
\n
'
.
join
([
'import {}'
.
format
(
pkg
)
for
pkg
in
total_pkgs
])
return
_PyTorchScriptTemplate
.
format
(
pkgs_code
,
'
\n\n
'
.
join
(
graphs
)).
strip
()
def
_sorted_incoming_edges
(
node
:
Node
)
->
List
[
Edge
]:
edges
=
[
edge
for
edge
in
node
.
graph
.
edges
if
edge
.
tail
is
node
]
_logger
.
info
(
'sorted_incoming_edges:
{}'
.
format
(
edges
))
_logger
.
info
(
'sorted_incoming_edges:
%s'
,
str
(
edges
))
if
not
edges
:
return
[]
_logger
.
info
(
f
'all tail_slots are None:
{
[
edge
.
tail_slot
for
edge
in
edges
]
}
'
)
_logger
.
info
(
'all tail_slots are None:
%s'
,
str
(
[
edge
.
tail_slot
for
edge
in
edges
]
)
)
if
all
(
edge
.
tail_slot
is
None
for
edge
in
edges
):
return
edges
if
all
(
isinstance
(
edge
.
tail_slot
,
int
)
for
edge
in
edges
):
...
...
@@ -32,6 +31,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
return
edges
raise
IllegalGraphError
(
node
.
graph
,
'Node {} has bad inputs'
.
format
(
node
.
name
))
def
_format_inputs
(
node
:
Node
)
->
List
[
str
]:
edges
=
_sorted_incoming_edges
(
node
)
inputs
=
[]
...
...
@@ -53,6 +53,7 @@ def _format_inputs(node: Node) -> List[str]:
inputs
.
append
(
'{}[{}]'
.
format
(
edge
.
head
.
name
,
edge
.
head_slot
))
return
inputs
def
_remove_prefix
(
names
,
graph_name
):
"""
variables name (full name space) is too long,
...
...
@@ -69,14 +70,14 @@ def _remove_prefix(names, graph_name):
else
:
return
names
[
len
(
graph_name
):]
if
names
.
startswith
(
graph_name
)
else
names
def
graph_to_pytorch_model
(
graph_name
:
str
,
graph
:
Graph
,
placement
=
None
)
->
str
:
def
graph_to_pytorch_model
(
graph_name
:
str
,
graph
:
Graph
,
placement
=
None
)
->
str
:
nodes
=
graph
.
topo_sort
()
# handle module node and function node differently
# only need to generate code for module here
import_pkgs
=
set
()
node_codes
=
[]
placement_codes
=
[]
for
node
in
nodes
:
if
node
.
operation
:
pkg_name
=
node
.
operation
.
get_import_pkg
()
...
...
nni/retiarii/converter/__init__.py
View file @
59cd3982
from
.graph_gen
import
convert_to_graph
from
.visualize
import
visualize_model
\ No newline at end of file
nni/retiarii/converter/graph_gen.py
View file @
59cd3982
import
json_tricks
import
logging
import
re
import
torch
from
..graph
import
Graph
,
Node
,
Edge
,
Model
from
..operation
import
Cell
,
Operation
from
..nn.pytorch
import
Placeholder
,
LayerChoice
,
InputChoice
import
torch
from
.op_types
import
MODULE_EXCEPT_LIST
,
OpTypeName
,
BasicOpsPT
from
.utils
import
build_full_name
,
_convert_name
from
..graph
import
Graph
,
Model
,
Node
from
..nn.pytorch
import
InputChoice
,
LayerChoice
,
Placeholder
from
..operation
import
Cell
from
.op_types
import
MODULE_EXCEPT_LIST
,
BasicOpsPT
,
OpTypeName
from
.utils
import
_convert_name
,
build_full_name
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -16,6 +15,7 @@ global_seq = 0
global_graph_id
=
0
modules_arg
=
None
def
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
new_node
,
output_remap
,
ignore_first
=
False
):
"""
Parameters
...
...
@@ -76,6 +76,7 @@ def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap,
new_node_input_idx
+=
1
def
create_prim_constant_node
(
ir_graph
,
node
,
module_name
):
global
global_seq
attrs
=
{}
...
...
@@ -86,14 +87,17 @@ def create_prim_constant_node(ir_graph, node, module_name):
node
.
kind
(),
attrs
)
return
new_node
def
handle_prim_attr_node
(
node
):
assert
node
.
hasAttribute
(
'name'
)
attrs
=
{
'name'
:
node
.
s
(
'name'
),
'input'
:
node
.
inputsAt
(
0
).
debugName
()}
return
node
.
kind
(),
attrs
def
_remove_mangle
(
module_type_str
):
return
re
.
sub
(
'
\\
.___torch_mangle_
\\
d+'
,
''
,
module_type_str
)
def
remove_unconnected_nodes
(
ir_graph
,
targeted_type
=
None
):
"""
Parameters
...
...
@@ -122,6 +126,7 @@ def remove_unconnected_nodes(ir_graph, targeted_type=None):
for
hidden_node
in
to_removes
:
hidden_node
.
remove
()
def
handle_graph_nodes
(
script_module
,
sm_graph
,
module
,
module_name
,
ir_model
,
ir_graph
):
"""
Convert torch script node to our node ir, and build our graph ir
...
...
@@ -248,7 +253,8 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
# therefore, we do this check for a module. example below:
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
assert
submodule_name
in
script_module
.
_modules
,
"submodule_name: {} not in script_module {}"
.
format
(
submodule_name
,
script_module
.
_modules
.
keys
())
assert
submodule_name
in
script_module
.
_modules
,
"submodule_name: {} not in script_module {}"
.
format
(
submodule_name
,
script_module
.
_modules
.
keys
())
submodule_full_name
=
build_full_name
(
module_name
,
submodule_name
)
submodule_obj
=
getattr
(
module
,
submodule_name
)
...
...
@@ -350,6 +356,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
return
node_index
def
merge_aten_slices
(
ir_graph
):
"""
if there is aten::slice node, merge the consecutive ones together.
...
...
@@ -408,13 +415,14 @@ def refine_graph(ir_graph):
remove_unconnected_nodes
(
ir_graph
,
targeted_type
=
'prim::GetAttr'
)
merge_aten_slices
(
ir_graph
)
def
_handle_layerchoice
(
module
):
global
modules_arg
m_attrs
=
{}
candidates
=
module
.
candidate_ops
choices
=
[]
for
i
,
cand
in
enumerate
(
candidates
)
:
for
cand
in
candidates
:
assert
id
(
cand
)
in
modules_arg
,
'id not exist: {}'
.
format
(
id
(
cand
))
assert
isinstance
(
modules_arg
[
id
(
cand
)],
dict
)
cand_type
=
'__torch__.'
+
cand
.
__class__
.
__module__
+
'.'
+
cand
.
__class__
.
__name__
...
...
@@ -423,6 +431,7 @@ def _handle_layerchoice(module):
m_attrs
[
'label'
]
=
module
.
label
return
m_attrs
def
_handle_inputchoice
(
module
):
m_attrs
=
{}
m_attrs
[
'n_chosen'
]
=
module
.
n_chosen
...
...
@@ -430,6 +439,7 @@ def _handle_inputchoice(module):
m_attrs
[
'label'
]
=
module
.
label
return
m_attrs
def
convert_module
(
script_module
,
module
,
module_name
,
ir_model
):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
...
...
@@ -507,6 +517,7 @@ def convert_module(script_module, module, module_name, ir_model):
# should not be parsed further.
return
ir_graph
,
{}
def
convert_to_graph
(
script_module
,
module
,
recorded_modules_arg
):
"""
Convert module to our graph ir, i.e., build a ```Model``` type
...
...
nni/retiarii/converter/op_types.py
View file @
59cd3982
...
...
@@ -16,6 +16,7 @@ class OpTypeName(str, Enum):
Placeholder
=
'Placeholder'
MergedSlice
=
'MergedSlice'
# deal with aten op
BasicOpsPT
=
{
'aten::mean'
:
'Mean'
,
...
...
nni/retiarii/converter/utils.py
View file @
59cd3982
...
...
@@ -6,6 +6,7 @@ def build_full_name(prefix, name, seq=None):
else
:
return
'{}__{}{}'
.
format
(
prefix
,
name
,
str
(
seq
))
def
_convert_name
(
name
:
str
)
->
str
:
"""
Convert the names using separator '.' to valid variable name in code
...
...
nni/retiarii/converter/visualize.py
View file @
59cd3982
import
graphviz
def
convert_to_visualize
(
graph_ir
,
vgraph
):
for
name
,
graph
in
graph_ir
.
items
():
if
name
==
'_training_config'
:
...
...
@@ -33,6 +34,7 @@ def convert_to_visualize(graph_ir, vgraph):
dst
=
cell_node
[
dst
][
0
]
subgraph
.
edge
(
src
,
dst
)
def
visualize_model
(
graph_ir
):
vgraph
=
graphviz
.
Digraph
(
'G'
,
filename
=
'vgraph'
,
format
=
'jpg'
)
convert_to_visualize
(
graph_ir
,
vgraph
)
...
...
nni/retiarii/execution/api.py
View file @
59cd3982
import
time
import
os
import
importlib.util
from
typing
import
*
from
typing
import
List
from
..graph
import
Model
,
ModelStatus
from
.base
import
BaseExecutionEngine
from
.cgo_engine
import
CGOExecutionEngine
from
.interface
import
*
from
.interface
import
AbstractExecutionEngine
,
WorkerInfo
from
.listener
import
DefaultListener
_execution_engine
=
None
...
...
nni/retiarii/execution/base.py
View file @
59cd3982
import
logging
from
typing
import
*
from
typing
import
Dict
,
Any
,
List
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
,
WorkerInfo
from
..
import
codegen
,
utils
...
...
@@ -61,16 +61,16 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def
_send_trial_callback
(
self
,
paramater
:
dict
)
->
None
:
for
listener
in
self
.
_listeners
:
_logger
.
warning
(
'resources:
{}'
.
format
(
listener
.
resources
)
)
_logger
.
warning
(
'resources:
%s'
,
listener
.
resources
)
if
not
listener
.
has_available_resource
():
_logger
.
warning
(
'There is no available resource, but trial is submitted.'
)
listener
.
on_resource_used
(
1
)
_logger
.
warning
(
'on_resource_used:
{}'
.
format
(
listener
.
resources
)
)
_logger
.
warning
(
'on_resource_used:
%s'
,
listener
.
resources
)
def
_request_trial_jobs_callback
(
self
,
num_trials
:
int
)
->
None
:
for
listener
in
self
.
_listeners
:
listener
.
on_resource_available
(
1
*
num_trials
)
_logger
.
warning
(
'on_resource_available:
{}'
.
format
(
listener
.
resources
)
)
_logger
.
warning
(
'on_resource_available:
%s'
,
listener
.
resources
)
def
_trial_end_callback
(
self
,
trial_id
:
int
,
success
:
bool
)
->
None
:
model
=
self
.
_running_models
[
trial_id
]
...
...
nni/retiarii/execution/cgo_engine.py
View file @
59cd3982
import
logging
import
json
from
typing
import
*
from
typing
import
List
,
Dict
,
Tuple
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
,
WorkerInfo
from
..
import
codegen
,
utils
...
...
@@ -12,8 +11,10 @@ from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
from
.base
import
BaseGraphData
_logger
=
logging
.
getLogger
(
__name__
)
class
CGOExecutionEngine
(
AbstractExecutionEngine
):
def
__init__
(
self
,
n_model_per_graph
=
4
)
->
None
:
def
__init__
(
self
,
n_model_per_graph
=
4
)
->
None
:
self
.
_listeners
:
List
[
AbstractGraphListener
]
=
[]
self
.
_running_models
:
Dict
[
int
,
Model
]
=
dict
()
self
.
logical_plan_counter
=
0
...
...
@@ -30,12 +31,11 @@ class CGOExecutionEngine(AbstractExecutionEngine):
advisor
.
intermediate_metric_callback
=
self
.
_intermediate_metric_callback
advisor
.
final_metric_callback
=
self
.
_final_metric_callback
def
add_optimizer
(
self
,
opt
):
self
.
_optimizers
.
append
(
opt
)
def
submit_models
(
self
,
*
models
:
List
[
Model
])
->
None
:
_logger
.
info
(
f
'
{
len
(
models
)
}
M
odels are submitted'
)
_logger
.
info
(
'%d m
odels are submitted'
,
len
(
models
)
)
logical
=
self
.
_build_logical
(
models
)
for
opt
in
self
.
_optimizers
:
...
...
@@ -55,13 +55,13 @@ class CGOExecutionEngine(AbstractExecutionEngine):
# model.config['trainer_module'], model.config['trainer_kwargs'])
# self._running_models[send_trial(data.dump())] = model
def
_assemble
(
self
,
logical_plan
:
LogicalPlan
)
->
List
[
Tuple
[
Model
,
PhysicalDevice
]]:
def
_assemble
(
self
,
logical_plan
:
LogicalPlan
)
->
List
[
Tuple
[
Model
,
PhysicalDevice
]]:
# unique_models = set()
# for node in logical_plan.graph.nodes:
# if node.graph.model not in unique_models:
# unique_models.add(node.graph.model)
# return [m for m in unique_models]
grouped_models
:
List
[
Dict
[
Model
,
PhysicalDevice
]]
=
AssemblePolicy
().
group
(
logical_plan
)
grouped_models
:
List
[
Dict
[
Model
,
PhysicalDevice
]]
=
AssemblePolicy
().
group
(
logical_plan
)
phy_models_and_placements
=
[]
for
multi_model
in
grouped_models
:
model
,
model_placement
=
logical_plan
.
assemble
(
multi_model
)
...
...
@@ -69,7 +69,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
return
phy_models_and_placements
def
_build_logical
(
self
,
models
:
List
[
Model
])
->
LogicalPlan
:
logical_plan
=
LogicalPlan
(
id
=
self
.
logical_plan_counter
)
logical_plan
=
LogicalPlan
(
plan_id
=
self
.
logical_plan_counter
)
for
model
in
models
:
logical_plan
.
add_model
(
model
)
self
.
logical_plan_counter
+=
1
...
...
@@ -108,7 +108,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
for
model_id
in
merged_metrics
:
int_model_id
=
int
(
model_id
)
self
.
_original_models
[
int_model_id
].
intermediate_metrics
.
append
(
merged_metrics
[
model_id
])
#model.intermediate_metrics.append(metrics)
#
model.intermediate_metrics.append(metrics)
for
listener
in
self
.
_listeners
:
listener
.
on_intermediate_metric
(
self
.
_original_models
[
int_model_id
],
merged_metrics
[
model_id
])
...
...
@@ -117,11 +117,10 @@ class CGOExecutionEngine(AbstractExecutionEngine):
for
model_id
in
merged_metrics
:
int_model_id
=
int
(
model_id
)
self
.
_original_models
[
int_model_id
].
intermediate_metrics
.
append
(
merged_metrics
[
model_id
])
#model.intermediate_metrics.append(metrics)
#
model.intermediate_metrics.append(metrics)
for
listener
in
self
.
_listeners
:
listener
.
on_metric
(
self
.
_original_models
[
int_model_id
],
merged_metrics
[
model_id
])
def
query_available_resource
(
self
)
->
List
[
WorkerInfo
]:
raise
NotImplementedError
# move the method from listener to here?
...
...
@@ -141,6 +140,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
trainer_instance
=
trainer_cls
(
model_cls
(),
graph_data
.
training_kwargs
)
trainer_instance
.
fit
()
class
AssemblePolicy
:
@
staticmethod
def
group
(
logical_plan
):
...
...
@@ -148,4 +148,3 @@ class AssemblePolicy:
for
idx
,
m
in
enumerate
(
logical_plan
.
models
):
group_model
[
m
]
=
PhysicalDevice
(
'server'
,
f
'cuda:
{
idx
}
'
)
return
[
group_model
]
\ No newline at end of file
nni/retiarii/execution/interface.py
View file @
59cd3982
from
abc
import
*
from
typing
import
*
from
abc
import
ABC
,
abstractmethod
,
abstractclassmethod
from
typing
import
Any
,
NewType
,
List
from
..graph
import
Model
,
MetricData
...
...
nni/retiarii/execution/listener.py
View file @
59cd3982
from
typing
import
*
from
..graph
import
*
from
.interface
import
*
from
..graph
import
Model
,
ModelStatus
from
.interface
import
MetricData
,
AbstractGraphListener
class
DefaultListener
(
AbstractGraphListener
):
...
...
nni/retiarii/execution/logical_optimizer/interface.py
View file @
59cd3982
from
abc
import
*
from
typing
import
*
from
abc
import
ABC
from
.logical_plan
import
LogicalPlan
class
AbstractOptimizer
(
ABC
):
def
__init__
(
self
)
->
None
:
pass
...
...
nni/retiarii/execution/logical_optimizer/logical_plan.py
View file @
59cd3982
from
nni.retiarii.operation
import
Operation
from
nni.retiarii.graph
import
Model
,
Graph
,
Edge
,
Node
,
Cell
from
typing
import
*
import
logging
from
nni.retiarii.operation
import
_IOPseudoOperation
import
copy
from
typing
import
Dict
,
Tuple
,
List
,
Any
from
...graph
import
Cell
,
Edge
,
Graph
,
Model
,
Node
from
...operation
import
Operation
,
_IOPseudoOperation
class
PhysicalDevice
:
...
...
@@ -108,11 +107,11 @@ class OriginNode(AbstractLogicalNode):
class
LogicalPlan
:
def
__init__
(
self
,
id
=
0
)
->
None
:
def
__init__
(
self
,
plan_
id
=
0
)
->
None
:
self
.
lp_model
=
Model
(
_internal
=
True
)
self
.
id
=
id
self
.
id
=
plan_
id
self
.
logical_graph
=
LogicalGraph
(
self
.
lp_model
,
id
,
name
=
f
'
{
id
}
'
,
_internal
=
True
).
_register
()
self
.
lp_model
,
self
.
id
,
name
=
f
'
{
self
.
id
}
'
,
_internal
=
True
).
_register
()
self
.
lp_model
.
_root_graph_name
=
self
.
logical_graph
.
name
self
.
models
=
[]
...
...
@@ -148,7 +147,7 @@ class LogicalPlan:
phy_model
.
training_config
.
kwargs
[
'is_multi_model'
]
=
True
phy_model
.
training_config
.
kwargs
[
'model_cls'
]
=
phy_graph
.
name
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
]
=
[]
#FIXME: allow user to specify
#
FIXME: allow user to specify
phy_model
.
training_config
.
module
=
'nni.retiarii.trainer.PyTorchMultiModelTrainer'
# merge sub-graphs
...
...
@@ -158,7 +157,6 @@ class LogicalPlan:
model
.
graphs
[
graph_name
].
_fork_to
(
phy_model
,
name_prefix
=
f
'M_
{
model
.
model_id
}
_'
)
# When replace logical nodes, merge the training configs when
# input/output nodes are replaced.
training_config_slot
=
{}
# Model ID -> Slot ID
...
...
@@ -230,7 +228,7 @@ class LogicalPlan:
to_node
=
copied_op
[(
edge
.
head
,
tail_placement
)]
else
:
to_operation
=
Operation
.
new
(
'ToDevice'
,
{
"device"
:
tail_placement
.
device
})
'ToDevice'
,
{
"device"
:
tail_placement
.
device
})
to_node
=
Node
(
phy_graph
,
phy_model
.
_uid
(),
edge
.
head
.
name
+
"_to_"
+
edge
.
tail
.
name
,
to_operation
).
_register
()
Edge
((
edge
.
head
,
edge
.
head_slot
),
...
...
@@ -250,7 +248,6 @@ class LogicalPlan:
edge
.
head_slot
=
input_slot_mapping
[
edge
.
head
]
edge
.
head
=
phy_graph
.
input_node
# merge all output nodes into one with multiple slots
output_nodes
=
[]
for
node
in
phy_graph
.
hidden_nodes
:
...
...
nni/retiarii/execution/logical_optimizer/opt_batching.py
deleted
100644 → 0
View file @
593a275c
from
.base_optimizer
import
BaseOptimizer
from
.logical_plan
import
LogicalPlan
class
BatchingOptimizer
(
BaseOptimizer
):
def
__init__
(
self
)
->
None
:
pass
def
convert
(
self
,
logical_plan
:
LogicalPlan
)
->
None
:
pass
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
View file @
59cd3982
from
.interface
import
AbstractOptimizer
from
.logical_plan
import
LogicalPlan
,
AbstractLogicalNode
,
LogicalGraph
,
OriginNode
,
PhysicalDevice
from
nni.retiarii
import
Graph
,
Node
,
Model
from
typing
import
*
from
nni.retiarii.operation
import
_IOPseudoOperation
from
typing
import
List
,
Dict
,
Tuple
from
...graph
import
Graph
,
Model
,
Node
from
.interface
import
AbstractOptimizer
from
.logical_plan
import
(
AbstractLogicalNode
,
LogicalGraph
,
LogicalPlan
,
OriginNode
,
PhysicalDevice
)
_supported_training_modules
=
[
'nni.retiarii.trainer.PyTorchImageClassificationTrainer'
]
class
DedupInputNode
(
AbstractLogicalNode
):
def
__init__
(
self
,
logical_graph
:
LogicalGraph
,
id
:
int
,
\
nodes_to_dedup
:
List
[
Node
],
_internal
=
False
):
super
().
__init__
(
logical_graph
,
id
,
\
"Dedup_"
+
nodes_to_dedup
[
0
].
name
,
\
def
__init__
(
self
,
logical_graph
:
LogicalGraph
,
node_
id
:
int
,
nodes_to_dedup
:
List
[
Node
],
_internal
=
False
):
super
().
__init__
(
logical_graph
,
node_
id
,
"Dedup_"
+
nodes_to_dedup
[
0
].
name
,
nodes_to_dedup
[
0
].
operation
)
self
.
origin_nodes
:
List
[
OriginNode
]
=
nodes_to_dedup
.
copy
()
self
.
origin_nodes
:
List
[
OriginNode
]
=
nodes_to_dedup
.
copy
()
def
assemble
(
self
,
multi_model_placement
:
Dict
[
Model
,
PhysicalDevice
])
->
Tuple
[
Node
,
PhysicalDevice
]:
for
node
in
self
.
origin_nodes
:
if
node
.
original_graph
.
model
in
multi_model_placement
:
new_node
=
Node
(
node
.
original_graph
,
node
.
id
,
\
f
'M_
{
node
.
original_graph
.
model
.
model_id
}
_
{
node
.
name
}
'
,
\
new_node
=
Node
(
node
.
original_graph
,
node
.
id
,
f
'M_
{
node
.
original_graph
.
model
.
model_id
}
_
{
node
.
name
}
'
,
node
.
operation
)
return
new_node
,
multi_model_placement
[
node
.
original_graph
.
model
]
raise
ValueError
(
f
'DedupInputNode
{
self
.
name
}
does not contain nodes from multi_model'
)
...
...
@@ -26,7 +28,6 @@ class DedupInputNode(AbstractLogicalNode):
def
_fork_to
(
self
,
graph
:
Graph
):
DedupInputNode
(
graph
,
self
.
id
,
self
.
origin_nodes
).
_register
()
def
__repr__
(
self
)
->
str
:
return
f
'DedupNode(id=
{
self
.
id
}
, name=
{
self
.
name
}
,
\
len(nodes_to_dedup)=
{
len
(
self
.
origin_nodes
)
}
'
...
...
@@ -35,6 +36,7 @@ class DedupInputNode(AbstractLogicalNode):
class
DedupInputOptimizer
(
AbstractOptimizer
):
def
__init__
(
self
)
->
None
:
pass
def
_check_deduplicate_by_node
(
self
,
root_node
,
node_to_check
):
if
root_node
==
node_to_check
:
return
True
...
...
@@ -51,12 +53,11 @@ class DedupInputOptimizer(AbstractOptimizer):
else
:
return
False
def
convert
(
self
,
logical_plan
:
LogicalPlan
)
->
None
:
nodes_to_skip
=
set
()
while
True
:
# repeat until the logical_graph converges
input_nodes
=
logical_plan
.
logical_graph
.
get_nodes_by_type
(
"_inputs"
)
#_PseudoOperation(type_name="_inputs"))
#
_PseudoOperation(type_name="_inputs"))
root_node
=
None
for
node
in
input_nodes
:
if
node
in
nodes_to_skip
:
...
...
@@ -77,7 +78,7 @@ class DedupInputOptimizer(AbstractOptimizer):
assert
(
nodes_to_dedup
[
0
]
==
root_node
)
nodes_to_skip
.
add
(
root_node
)
else
:
dedup_node
=
DedupInputNode
(
logical_plan
.
logical_graph
,
\
dedup_node
=
DedupInputNode
(
logical_plan
.
logical_graph
,
logical_plan
.
lp_model
.
_uid
(),
nodes_to_dedup
).
_register
()
for
edge
in
logical_plan
.
logical_graph
.
edges
:
if
edge
.
head
in
nodes_to_dedup
:
...
...
nni/retiarii/execution/logical_optimizer/opt_weight_sharing.py
deleted
100644 → 0
View file @
593a275c
from
.base_optimizer
import
BaseOptimizer
from
.logical_plan
import
LogicalPlan
class
WeightSharingOptimizer
(
BaseOptimizer
):
def
__init__
(
self
)
->
None
:
pass
def
convert
(
self
,
logical_plan
:
LogicalPlan
)
->
None
:
pass
nni/retiarii/experiment.py
View file @
59cd3982
import
dataclasses
import
logging
import
time
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
subprocess
import
Popen
from
threading
import
Thread
from
typing
import
Any
,
List
,
Optional
from
typing
import
Any
,
Optional
from
..experiment
import
Experiment
,
TrainingServiceConfig
from
..experiment
import
launcher
,
rest
from
..experiment
import
Experiment
,
TrainingServiceConfig
,
launcher
,
rest
from
..experiment.config.base
import
ConfigBase
,
PathLike
from
..experiment.config
import
util
from
..experiment.pipe
import
Pipe
from
.graph
import
Model
from
.utils
import
get_records
from
.integration
import
RetiariiAdvisor
from
.converter.graph_gen
import
convert_to_graph
from
.mutator
import
LayerChoiceMutator
,
InputChoiceMutator
from
.converter
import
convert_to_graph
from
.mutator
import
Mutator
,
LayerChoiceMutator
,
InputChoiceMutator
from
.trainer.interface
import
BaseTrainer
from
.strategies.strategy
import
BaseStrategy
_logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
(
init
=
False
)
class
RetiariiExeConfig
(
ConfigBase
):
experiment_name
:
Optional
[
str
]
=
None
...
...
@@ -52,6 +56,7 @@ class RetiariiExeConfig(ConfigBase):
def
_validation_rules
(
self
):
return
_validation_rules
_canonical_rules
=
{
'trial_code_directory'
:
util
.
canonical_path
,
'max_experiment_duration'
:
lambda
value
:
f
'
{
util
.
parse_time
(
value
)
}
s'
if
value
is
not
None
else
None
,
...
...
@@ -70,8 +75,8 @@ _validation_rules = {
class
RetiariiExperiment
(
Experiment
):
def
__init__
(
self
,
base_model
:
'nn.Module'
,
trainer
:
'
BaseTrainer
'
,
applied_mutators
:
List
[
'
Mutator
'
]
,
strategy
:
'
BaseStrategy
'
):
def
__init__
(
self
,
base_model
:
Model
,
trainer
:
BaseTrainer
,
applied_mutators
:
Mutator
,
strategy
:
BaseStrategy
):
self
.
config
:
RetiariiExeConfig
=
None
self
.
port
:
Optional
[
int
]
=
None
...
...
nni/retiarii/graph.py
View file @
59cd3982
...
...
@@ -5,7 +5,6 @@ Model representation.
import
copy
from
enum
import
Enum
import
json
from
collections
import
defaultdict
from
typing
import
(
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
,
overload
)
from
.operation
import
Cell
,
Operation
,
_IOPseudoOperation
...
...
@@ -330,11 +329,11 @@ class Graph:
"""
return
[
node
for
node
in
self
.
hidden_nodes
if
node
.
operation
.
type
==
operation_type
]
def
get_node_by_id
(
self
,
id
:
int
)
->
Optional
[
'Node'
]:
def
get_node_by_id
(
self
,
node_
id
:
int
)
->
Optional
[
'Node'
]:
"""
Returns the node which has specified name; or returns `None` if no node has this name.
"""
found
=
[
node
for
node
in
self
.
nodes
if
node
.
id
==
id
]
found
=
[
node
for
node
in
self
.
nodes
if
node
.
id
==
node_
id
]
return
found
[
0
]
if
found
else
None
def
get_nodes_by_label
(
self
,
label
:
str
)
->
List
[
'Node'
]:
...
...
@@ -365,7 +364,8 @@ class Graph:
curr_nodes
.
append
(
successor
)
for
key
in
node_to_fanin
:
assert
node_to_fanin
[
key
]
==
0
,
'{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'
.
format
(
key
,
assert
node_to_fanin
[
key
]
==
0
,
'{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'
.
format
(
key
,
node_to_fanin
[
key
],
key
.
predecessors
[
0
],
self
.
edges
,
...
...
@@ -587,6 +587,7 @@ class Node:
ret
[
'label'
]
=
self
.
label
return
ret
class
Edge
:
"""
A tensor, or "data flow", between two nodes.
...
...
nni/retiarii/integration.py
View file @
59cd3982
import
logging
import
threading
from
typing
import
*
from
typing
import
Any
,
Callable
import
json_tricks
import
nni
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.protocol
import
send
,
CommandType
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.utils
import
MetricType
from
.
import
utils
from
.graph
import
MetricData
_logger
=
logging
.
getLogger
(
'nni.msg_dispatcher_base'
)
...
...
@@ -44,6 +41,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
final_metric_callback
"""
def
__init__
(
self
):
super
(
RetiariiAdvisor
,
self
).
__init__
()
register_advisor
(
self
)
# register the current advisor as the "global only" advisor
...
...
@@ -88,28 +86,28 @@ class RetiariiAdvisor(MsgDispatcherBase):
'parameters'
:
parameters
,
'parameter_source'
:
'algorithm'
}
_logger
.
info
(
'New trial sent:
{}'
.
format
(
new_trial
)
)
_logger
.
info
(
'New trial sent:
%s'
,
new_trial
)
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
new_trial
))
if
self
.
send_trial_callback
is
not
None
:
self
.
send_trial_callback
(
parameters
)
# pylint: disable=not-callable
return
self
.
parameters_count
def
handle_request_trial_jobs
(
self
,
num_trials
):
_logger
.
info
(
'Request trial jobs:
{}'
.
format
(
num_trials
)
)
_logger
.
info
(
'Request trial jobs:
%s'
,
num_trials
)
if
self
.
request_trial_jobs_callback
is
not
None
:
self
.
request_trial_jobs_callback
(
num_trials
)
# pylint: disable=not-callable
def
handle_update_search_space
(
self
,
data
):
_logger
.
info
(
'Received search space:
{}'
.
format
(
data
)
)
_logger
.
info
(
'Received search space:
%s'
,
data
)
self
.
search_space
=
data
def
handle_trial_end
(
self
,
data
):
_logger
.
info
(
'Trial end:
{}'
.
format
(
data
))
# do nothing
_logger
.
info
(
'Trial end:
%s'
,
data
)
self
.
trial_end_callback
(
json_tricks
.
loads
(
data
[
'hyper_params'
])[
'parameter_id'
],
# pylint: disable=not-callable
data
[
'event'
]
==
'SUCCEEDED'
)
def
handle_report_metric_data
(
self
,
data
):
_logger
.
info
(
'Metric reported:
{}'
.
format
(
data
)
)
_logger
.
info
(
'Metric reported:
%s'
,
data
)
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
raise
ValueError
(
'Request parameter not supported'
)
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
...
...
nni/retiarii/mutator.py
View file @
59cd3982
...
...
@@ -13,6 +13,7 @@ class Sampler:
"""
Handles `Mutator.choice()` calls.
"""
def
choice
(
self
,
candidates
:
List
[
Choice
],
mutator
:
'Mutator'
,
model
:
Model
,
index
:
int
)
->
Choice
:
raise
NotImplementedError
()
...
...
@@ -35,6 +36,7 @@ class Mutator:
For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
# Method names are open for discussion.
"""
def
__init__
(
self
,
sampler
:
Optional
[
Sampler
]
=
None
):
self
.
sampler
:
Optional
[
Sampler
]
=
sampler
self
.
_cur_model
:
Optional
[
Model
]
=
None
...
...
@@ -77,7 +79,6 @@ class Mutator:
self
.
sampler
=
sampler_backup
return
recorder
.
recorded_candidates
,
new_model
def
mutate
(
self
,
model
:
Model
)
->
None
:
"""
Abstract method to be implemented by subclass.
...
...
@@ -105,6 +106,7 @@ class _RecorderSampler(Sampler):
# the following is for inline mutation
class
LayerChoiceMutator
(
Mutator
):
def
__init__
(
self
,
node_name
:
str
,
candidates
:
List
):
super
().
__init__
()
...
...
@@ -118,6 +120,7 @@ class LayerChoiceMutator(Mutator):
chosen_cand
=
self
.
candidates
[
chosen_index
]
target
.
update_operation
(
chosen_cand
[
'type'
],
chosen_cand
[
'parameters'
])
class
InputChoiceMutator
(
Mutator
):
def
__init__
(
self
,
node_name
:
str
,
n_chosen
:
int
):
super
().
__init__
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment