Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
59cd3982
Unverified
Commit
59cd3982
authored
Dec 14, 2020
by
Yuge Zhang
Committed by
GitHub
Dec 14, 2020
Browse files
[Retiarii] Coding style improvements for pylint and flake8 (#3190)
parent
593a275c
Changes
34
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
148 additions
and
152 deletions
+148
-152
nni/retiarii/codegen/pytorch.py
nni/retiarii/codegen/pytorch.py
+10
-9
nni/retiarii/converter/__init__.py
nni/retiarii/converter/__init__.py
+0
-1
nni/retiarii/converter/graph_gen.py
nni/retiarii/converter/graph_gen.py
+29
-18
nni/retiarii/converter/op_types.py
nni/retiarii/converter/op_types.py
+3
-2
nni/retiarii/converter/utils.py
nni/retiarii/converter/utils.py
+1
-0
nni/retiarii/converter/visualize.py
nni/retiarii/converter/visualize.py
+3
-1
nni/retiarii/execution/api.py
nni/retiarii/execution/api.py
+2
-3
nni/retiarii/execution/base.py
nni/retiarii/execution/base.py
+4
-4
nni/retiarii/execution/cgo_engine.py
nni/retiarii/execution/cgo_engine.py
+16
-17
nni/retiarii/execution/interface.py
nni/retiarii/execution/interface.py
+2
-2
nni/retiarii/execution/listener.py
nni/retiarii/execution/listener.py
+2
-4
nni/retiarii/execution/logical_optimizer/interface.py
nni/retiarii/execution/logical_optimizer/interface.py
+3
-3
nni/retiarii/execution/logical_optimizer/logical_plan.py
nni/retiarii/execution/logical_optimizer/logical_plan.py
+12
-15
nni/retiarii/execution/logical_optimizer/opt_batching.py
nni/retiarii/execution/logical_optimizer/opt_batching.py
+0
-10
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
+26
-25
nni/retiarii/execution/logical_optimizer/opt_weight_sharing.py
...etiarii/execution/logical_optimizer/opt_weight_sharing.py
+0
-10
nni/retiarii/experiment.py
nni/retiarii/experiment.py
+16
-11
nni/retiarii/graph.py
nni/retiarii/graph.py
+6
-5
nni/retiarii/integration.py
nni/retiarii/integration.py
+8
-10
nni/retiarii/mutator.py
nni/retiarii/mutator.py
+5
-2
No files found.
nni/retiarii/codegen/pytorch.py
View file @
59cd3982
import
logging
from
typing
import
*
from
typing
import
List
from
..graph
import
IllegalGraphError
,
Edge
,
Graph
,
Node
,
Model
from
..operation
import
Operation
,
Cell
_logger
=
logging
.
getLogger
(
__name__
)
def
model_to_pytorch_script
(
model
:
Model
,
placement
=
None
)
->
str
:
def
model_to_pytorch_script
(
model
:
Model
,
placement
=
None
)
->
str
:
graphs
=
[]
total_pkgs
=
set
()
for
name
,
cell
in
model
.
graphs
.
items
():
import_pkgs
,
graph_code
=
graph_to_pytorch_model
(
name
,
cell
,
placement
=
placement
)
import_pkgs
,
graph_code
=
graph_to_pytorch_model
(
name
,
cell
,
placement
=
placement
)
graphs
.
append
(
graph_code
)
total_pkgs
.
update
(
import_pkgs
)
pkgs_code
=
'
\n
'
.
join
([
'import {}'
.
format
(
pkg
)
for
pkg
in
total_pkgs
])
return
_PyTorchScriptTemplate
.
format
(
pkgs_code
,
'
\n\n
'
.
join
(
graphs
)).
strip
()
def
_sorted_incoming_edges
(
node
:
Node
)
->
List
[
Edge
]:
edges
=
[
edge
for
edge
in
node
.
graph
.
edges
if
edge
.
tail
is
node
]
_logger
.
info
(
'sorted_incoming_edges:
{}'
.
format
(
edges
))
_logger
.
info
(
'sorted_incoming_edges:
%s'
,
str
(
edges
))
if
not
edges
:
return
[]
_logger
.
info
(
f
'all tail_slots are None:
{
[
edge
.
tail_slot
for
edge
in
edges
]
}
'
)
_logger
.
info
(
'all tail_slots are None:
%s'
,
str
(
[
edge
.
tail_slot
for
edge
in
edges
]
)
)
if
all
(
edge
.
tail_slot
is
None
for
edge
in
edges
):
return
edges
if
all
(
isinstance
(
edge
.
tail_slot
,
int
)
for
edge
in
edges
):
...
...
@@ -32,6 +31,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
return
edges
raise
IllegalGraphError
(
node
.
graph
,
'Node {} has bad inputs'
.
format
(
node
.
name
))
def
_format_inputs
(
node
:
Node
)
->
List
[
str
]:
edges
=
_sorted_incoming_edges
(
node
)
inputs
=
[]
...
...
@@ -53,6 +53,7 @@ def _format_inputs(node: Node) -> List[str]:
inputs
.
append
(
'{}[{}]'
.
format
(
edge
.
head
.
name
,
edge
.
head_slot
))
return
inputs
def
_remove_prefix
(
names
,
graph_name
):
"""
variables name (full name space) is too long,
...
...
@@ -69,14 +70,14 @@ def _remove_prefix(names, graph_name):
else
:
return
names
[
len
(
graph_name
):]
if
names
.
startswith
(
graph_name
)
else
names
def
graph_to_pytorch_model
(
graph_name
:
str
,
graph
:
Graph
,
placement
=
None
)
->
str
:
def
graph_to_pytorch_model
(
graph_name
:
str
,
graph
:
Graph
,
placement
=
None
)
->
str
:
nodes
=
graph
.
topo_sort
()
# handle module node and function node differently
# only need to generate code for module here
import_pkgs
=
set
()
node_codes
=
[]
placement_codes
=
[]
for
node
in
nodes
:
if
node
.
operation
:
pkg_name
=
node
.
operation
.
get_import_pkg
()
...
...
nni/retiarii/converter/__init__.py
View file @
59cd3982
from
.graph_gen
import
convert_to_graph
from
.visualize
import
visualize_model
\ No newline at end of file
nni/retiarii/converter/graph_gen.py
View file @
59cd3982
import
json_tricks
import
logging
import
re
import
torch
from
..graph
import
Graph
,
Node
,
Edge
,
Model
from
..operation
import
Cell
,
Operation
from
..nn.pytorch
import
Placeholder
,
LayerChoice
,
InputChoice
import
torch
from
.op_types
import
MODULE_EXCEPT_LIST
,
OpTypeName
,
BasicOpsPT
from
.utils
import
build_full_name
,
_convert_name
from
..graph
import
Graph
,
Model
,
Node
from
..nn.pytorch
import
InputChoice
,
LayerChoice
,
Placeholder
from
..operation
import
Cell
from
.op_types
import
MODULE_EXCEPT_LIST
,
BasicOpsPT
,
OpTypeName
from
.utils
import
_convert_name
,
build_full_name
_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -16,6 +15,7 @@ global_seq = 0
global_graph_id
=
0
modules_arg
=
None
def
_add_edge
(
ir_graph
,
node
,
graph_inputs
,
node_index
,
new_node
,
output_remap
,
ignore_first
=
False
):
"""
Parameters
...
...
@@ -76,6 +76,7 @@ def _add_edge(ir_graph, node, graph_inputs, node_index, new_node, output_remap,
new_node_input_idx
+=
1
def
create_prim_constant_node
(
ir_graph
,
node
,
module_name
):
global
global_seq
attrs
=
{}
...
...
@@ -86,14 +87,17 @@ def create_prim_constant_node(ir_graph, node, module_name):
node
.
kind
(),
attrs
)
return
new_node
def
handle_prim_attr_node
(
node
):
assert
node
.
hasAttribute
(
'name'
)
attrs
=
{
'name'
:
node
.
s
(
'name'
),
'input'
:
node
.
inputsAt
(
0
).
debugName
()}
return
node
.
kind
(),
attrs
def
_remove_mangle
(
module_type_str
):
return
re
.
sub
(
'
\\
.___torch_mangle_
\\
d+'
,
''
,
module_type_str
)
def
remove_unconnected_nodes
(
ir_graph
,
targeted_type
=
None
):
"""
Parameters
...
...
@@ -122,6 +126,7 @@ def remove_unconnected_nodes(ir_graph, targeted_type=None):
for
hidden_node
in
to_removes
:
hidden_node
.
remove
()
def
handle_graph_nodes
(
script_module
,
sm_graph
,
module
,
module_name
,
ir_model
,
ir_graph
):
"""
Convert torch script node to our node ir, and build our graph ir
...
...
@@ -248,7 +253,8 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
# therefore, we do this check for a module. example below:
# %25 : __torch__.xxx = prim::GetAttr[name="input_switch"](%self)
# %27 : Tensor = prim::CallMethod[name="forward"](%25, %out.1)
assert
submodule_name
in
script_module
.
_modules
,
"submodule_name: {} not in script_module {}"
.
format
(
submodule_name
,
script_module
.
_modules
.
keys
())
assert
submodule_name
in
script_module
.
_modules
,
"submodule_name: {} not in script_module {}"
.
format
(
submodule_name
,
script_module
.
_modules
.
keys
())
submodule_full_name
=
build_full_name
(
module_name
,
submodule_name
)
submodule_obj
=
getattr
(
module
,
submodule_name
)
...
...
@@ -350,6 +356,7 @@ def handle_graph_nodes(script_module, sm_graph, module, module_name, ir_model, i
return
node_index
def
merge_aten_slices
(
ir_graph
):
"""
if there is aten::slice node, merge the consecutive ones together.
...
...
@@ -408,13 +415,14 @@ def refine_graph(ir_graph):
remove_unconnected_nodes
(
ir_graph
,
targeted_type
=
'prim::GetAttr'
)
merge_aten_slices
(
ir_graph
)
def
_handle_layerchoice
(
module
):
global
modules_arg
m_attrs
=
{}
candidates
=
module
.
candidate_ops
choices
=
[]
for
i
,
cand
in
enumerate
(
candidates
)
:
for
cand
in
candidates
:
assert
id
(
cand
)
in
modules_arg
,
'id not exist: {}'
.
format
(
id
(
cand
))
assert
isinstance
(
modules_arg
[
id
(
cand
)],
dict
)
cand_type
=
'__torch__.'
+
cand
.
__class__
.
__module__
+
'.'
+
cand
.
__class__
.
__name__
...
...
@@ -423,6 +431,7 @@ def _handle_layerchoice(module):
m_attrs
[
'label'
]
=
module
.
label
return
m_attrs
def
_handle_inputchoice
(
module
):
m_attrs
=
{}
m_attrs
[
'n_chosen'
]
=
module
.
n_chosen
...
...
@@ -430,6 +439,7 @@ def _handle_inputchoice(module):
m_attrs
[
'label'
]
=
module
.
label
return
m_attrs
def
convert_module
(
script_module
,
module
,
module_name
,
ir_model
):
"""
Convert a module to its graph ir (i.e., Graph) along with its input arguments
...
...
@@ -507,6 +517,7 @@ def convert_module(script_module, module, module_name, ir_model):
# should not be parsed further.
return
ir_graph
,
{}
def
convert_to_graph
(
script_module
,
module
,
recorded_modules_arg
):
"""
Convert module to our graph ir, i.e., build a ```Model``` type
...
...
nni/retiarii/converter/op_types.py
View file @
59cd3982
...
...
@@ -16,6 +16,7 @@ class OpTypeName(str, Enum):
Placeholder
=
'Placeholder'
MergedSlice
=
'MergedSlice'
# deal with aten op
BasicOpsPT
=
{
'aten::mean'
:
'Mean'
,
...
...
nni/retiarii/converter/utils.py
View file @
59cd3982
...
...
@@ -6,6 +6,7 @@ def build_full_name(prefix, name, seq=None):
else
:
return
'{}__{}{}'
.
format
(
prefix
,
name
,
str
(
seq
))
def
_convert_name
(
name
:
str
)
->
str
:
"""
Convert the names using separator '.' to valid variable name in code
...
...
nni/retiarii/converter/visualize.py
View file @
59cd3982
import
graphviz
def
convert_to_visualize
(
graph_ir
,
vgraph
):
for
name
,
graph
in
graph_ir
.
items
():
if
name
==
'_training_config'
:
...
...
@@ -33,6 +34,7 @@ def convert_to_visualize(graph_ir, vgraph):
dst
=
cell_node
[
dst
][
0
]
subgraph
.
edge
(
src
,
dst
)
def
visualize_model
(
graph_ir
):
vgraph
=
graphviz
.
Digraph
(
'G'
,
filename
=
'vgraph'
,
format
=
'jpg'
)
convert_to_visualize
(
graph_ir
,
vgraph
)
...
...
nni/retiarii/execution/api.py
View file @
59cd3982
import
time
import
os
import
importlib.util
from
typing
import
*
from
typing
import
List
from
..graph
import
Model
,
ModelStatus
from
.base
import
BaseExecutionEngine
from
.cgo_engine
import
CGOExecutionEngine
from
.interface
import
*
from
.interface
import
AbstractExecutionEngine
,
WorkerInfo
from
.listener
import
DefaultListener
_execution_engine
=
None
...
...
nni/retiarii/execution/base.py
View file @
59cd3982
import
logging
from
typing
import
*
from
typing
import
Dict
,
Any
,
List
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
,
WorkerInfo
from
..
import
codegen
,
utils
...
...
@@ -61,16 +61,16 @@ class BaseExecutionEngine(AbstractExecutionEngine):
def
_send_trial_callback
(
self
,
paramater
:
dict
)
->
None
:
for
listener
in
self
.
_listeners
:
_logger
.
warning
(
'resources:
{}'
.
format
(
listener
.
resources
)
)
_logger
.
warning
(
'resources:
%s'
,
listener
.
resources
)
if
not
listener
.
has_available_resource
():
_logger
.
warning
(
'There is no available resource, but trial is submitted.'
)
listener
.
on_resource_used
(
1
)
_logger
.
warning
(
'on_resource_used:
{}'
.
format
(
listener
.
resources
)
)
_logger
.
warning
(
'on_resource_used:
%s'
,
listener
.
resources
)
def
_request_trial_jobs_callback
(
self
,
num_trials
:
int
)
->
None
:
for
listener
in
self
.
_listeners
:
listener
.
on_resource_available
(
1
*
num_trials
)
_logger
.
warning
(
'on_resource_available:
{}'
.
format
(
listener
.
resources
)
)
_logger
.
warning
(
'on_resource_available:
%s'
,
listener
.
resources
)
def
_trial_end_callback
(
self
,
trial_id
:
int
,
success
:
bool
)
->
None
:
model
=
self
.
_running_models
[
trial_id
]
...
...
nni/retiarii/execution/cgo_engine.py
View file @
59cd3982
import
logging
import
json
from
typing
import
*
from
typing
import
List
,
Dict
,
Tuple
from
.interface
import
AbstractExecutionEngine
,
AbstractGraphListener
,
WorkerInfo
from
..
import
codegen
,
utils
...
...
@@ -12,8 +11,10 @@ from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
from
.base
import
BaseGraphData
_logger
=
logging
.
getLogger
(
__name__
)
class
CGOExecutionEngine
(
AbstractExecutionEngine
):
def
__init__
(
self
,
n_model_per_graph
=
4
)
->
None
:
def
__init__
(
self
,
n_model_per_graph
=
4
)
->
None
:
self
.
_listeners
:
List
[
AbstractGraphListener
]
=
[]
self
.
_running_models
:
Dict
[
int
,
Model
]
=
dict
()
self
.
logical_plan_counter
=
0
...
...
@@ -30,12 +31,11 @@ class CGOExecutionEngine(AbstractExecutionEngine):
advisor
.
intermediate_metric_callback
=
self
.
_intermediate_metric_callback
advisor
.
final_metric_callback
=
self
.
_final_metric_callback
def
add_optimizer
(
self
,
opt
):
self
.
_optimizers
.
append
(
opt
)
def
submit_models
(
self
,
*
models
:
List
[
Model
])
->
None
:
_logger
.
info
(
f
'
{
len
(
models
)
}
M
odels are submitted'
)
_logger
.
info
(
'%d m
odels are submitted'
,
len
(
models
)
)
logical
=
self
.
_build_logical
(
models
)
for
opt
in
self
.
_optimizers
:
...
...
@@ -55,13 +55,13 @@ class CGOExecutionEngine(AbstractExecutionEngine):
# model.config['trainer_module'], model.config['trainer_kwargs'])
# self._running_models[send_trial(data.dump())] = model
def
_assemble
(
self
,
logical_plan
:
LogicalPlan
)
->
List
[
Tuple
[
Model
,
PhysicalDevice
]]:
def
_assemble
(
self
,
logical_plan
:
LogicalPlan
)
->
List
[
Tuple
[
Model
,
PhysicalDevice
]]:
# unique_models = set()
# for node in logical_plan.graph.nodes:
# if node.graph.model not in unique_models:
# unique_models.add(node.graph.model)
# return [m for m in unique_models]
grouped_models
:
List
[
Dict
[
Model
,
PhysicalDevice
]]
=
AssemblePolicy
().
group
(
logical_plan
)
grouped_models
:
List
[
Dict
[
Model
,
PhysicalDevice
]]
=
AssemblePolicy
().
group
(
logical_plan
)
phy_models_and_placements
=
[]
for
multi_model
in
grouped_models
:
model
,
model_placement
=
logical_plan
.
assemble
(
multi_model
)
...
...
@@ -69,7 +69,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
return
phy_models_and_placements
def
_build_logical
(
self
,
models
:
List
[
Model
])
->
LogicalPlan
:
logical_plan
=
LogicalPlan
(
id
=
self
.
logical_plan_counter
)
logical_plan
=
LogicalPlan
(
plan_id
=
self
.
logical_plan_counter
)
for
model
in
models
:
logical_plan
.
add_model
(
model
)
self
.
logical_plan_counter
+=
1
...
...
@@ -108,7 +108,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
for
model_id
in
merged_metrics
:
int_model_id
=
int
(
model_id
)
self
.
_original_models
[
int_model_id
].
intermediate_metrics
.
append
(
merged_metrics
[
model_id
])
#model.intermediate_metrics.append(metrics)
#
model.intermediate_metrics.append(metrics)
for
listener
in
self
.
_listeners
:
listener
.
on_intermediate_metric
(
self
.
_original_models
[
int_model_id
],
merged_metrics
[
model_id
])
...
...
@@ -117,11 +117,10 @@ class CGOExecutionEngine(AbstractExecutionEngine):
for
model_id
in
merged_metrics
:
int_model_id
=
int
(
model_id
)
self
.
_original_models
[
int_model_id
].
intermediate_metrics
.
append
(
merged_metrics
[
model_id
])
#model.intermediate_metrics.append(metrics)
#
model.intermediate_metrics.append(metrics)
for
listener
in
self
.
_listeners
:
listener
.
on_metric
(
self
.
_original_models
[
int_model_id
],
merged_metrics
[
model_id
])
def
query_available_resource
(
self
)
->
List
[
WorkerInfo
]:
raise
NotImplementedError
# move the method from listener to here?
...
...
@@ -141,6 +140,7 @@ class CGOExecutionEngine(AbstractExecutionEngine):
trainer_instance
=
trainer_cls
(
model_cls
(),
graph_data
.
training_kwargs
)
trainer_instance
.
fit
()
class
AssemblePolicy
:
@
staticmethod
def
group
(
logical_plan
):
...
...
@@ -148,4 +148,3 @@ class AssemblePolicy:
for
idx
,
m
in
enumerate
(
logical_plan
.
models
):
group_model
[
m
]
=
PhysicalDevice
(
'server'
,
f
'cuda:
{
idx
}
'
)
return
[
group_model
]
\ No newline at end of file
nni/retiarii/execution/interface.py
View file @
59cd3982
from
abc
import
*
from
typing
import
*
from
abc
import
ABC
,
abstractmethod
,
abstractclassmethod
from
typing
import
Any
,
NewType
,
List
from
..graph
import
Model
,
MetricData
...
...
nni/retiarii/execution/listener.py
View file @
59cd3982
from
typing
import
*
from
..graph
import
*
from
.interface
import
*
from
..graph
import
Model
,
ModelStatus
from
.interface
import
MetricData
,
AbstractGraphListener
class
DefaultListener
(
AbstractGraphListener
):
...
...
nni/retiarii/execution/logical_optimizer/interface.py
View file @
59cd3982
from
abc
import
*
from
typing
import
*
from
abc
import
ABC
from
.logical_plan
import
LogicalPlan
class
AbstractOptimizer
(
ABC
):
def
__init__
(
self
)
->
None
:
pass
...
...
nni/retiarii/execution/logical_optimizer/logical_plan.py
View file @
59cd3982
from
nni.retiarii.operation
import
Operation
from
nni.retiarii.graph
import
Model
,
Graph
,
Edge
,
Node
,
Cell
from
typing
import
*
import
logging
from
nni.retiarii.operation
import
_IOPseudoOperation
import
copy
from
typing
import
Dict
,
Tuple
,
List
,
Any
from
...graph
import
Cell
,
Edge
,
Graph
,
Model
,
Node
from
...operation
import
Operation
,
_IOPseudoOperation
class
PhysicalDevice
:
...
...
@@ -108,11 +107,11 @@ class OriginNode(AbstractLogicalNode):
class
LogicalPlan
:
def
__init__
(
self
,
id
=
0
)
->
None
:
def
__init__
(
self
,
plan_
id
=
0
)
->
None
:
self
.
lp_model
=
Model
(
_internal
=
True
)
self
.
id
=
id
self
.
id
=
plan_
id
self
.
logical_graph
=
LogicalGraph
(
self
.
lp_model
,
id
,
name
=
f
'
{
id
}
'
,
_internal
=
True
).
_register
()
self
.
lp_model
,
self
.
id
,
name
=
f
'
{
self
.
id
}
'
,
_internal
=
True
).
_register
()
self
.
lp_model
.
_root_graph_name
=
self
.
logical_graph
.
name
self
.
models
=
[]
...
...
@@ -148,7 +147,7 @@ class LogicalPlan:
phy_model
.
training_config
.
kwargs
[
'is_multi_model'
]
=
True
phy_model
.
training_config
.
kwargs
[
'model_cls'
]
=
phy_graph
.
name
phy_model
.
training_config
.
kwargs
[
'model_kwargs'
]
=
[]
#FIXME: allow user to specify
#
FIXME: allow user to specify
phy_model
.
training_config
.
module
=
'nni.retiarii.trainer.PyTorchMultiModelTrainer'
# merge sub-graphs
...
...
@@ -158,7 +157,6 @@ class LogicalPlan:
model
.
graphs
[
graph_name
].
_fork_to
(
phy_model
,
name_prefix
=
f
'M_
{
model
.
model_id
}
_'
)
# When replace logical nodes, merge the training configs when
# input/output nodes are replaced.
training_config_slot
=
{}
# Model ID -> Slot ID
...
...
@@ -230,7 +228,7 @@ class LogicalPlan:
to_node
=
copied_op
[(
edge
.
head
,
tail_placement
)]
else
:
to_operation
=
Operation
.
new
(
'ToDevice'
,
{
"device"
:
tail_placement
.
device
})
'ToDevice'
,
{
"device"
:
tail_placement
.
device
})
to_node
=
Node
(
phy_graph
,
phy_model
.
_uid
(),
edge
.
head
.
name
+
"_to_"
+
edge
.
tail
.
name
,
to_operation
).
_register
()
Edge
((
edge
.
head
,
edge
.
head_slot
),
...
...
@@ -250,7 +248,6 @@ class LogicalPlan:
edge
.
head_slot
=
input_slot_mapping
[
edge
.
head
]
edge
.
head
=
phy_graph
.
input_node
# merge all output nodes into one with multiple slots
output_nodes
=
[]
for
node
in
phy_graph
.
hidden_nodes
:
...
...
nni/retiarii/execution/logical_optimizer/opt_batching.py
deleted
100644 → 0
View file @
593a275c
from
.base_optimizer
import
BaseOptimizer
from
.logical_plan
import
LogicalPlan
class
BatchingOptimizer
(
BaseOptimizer
):
def
__init__
(
self
)
->
None
:
pass
def
convert
(
self
,
logical_plan
:
LogicalPlan
)
->
None
:
pass
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py
View file @
59cd3982
from
.interface
import
AbstractOptimizer
from
.logical_plan
import
LogicalPlan
,
AbstractLogicalNode
,
LogicalGraph
,
OriginNode
,
PhysicalDevice
from
nni.retiarii
import
Graph
,
Node
,
Model
from
typing
import
*
from
nni.retiarii.operation
import
_IOPseudoOperation
from
typing
import
List
,
Dict
,
Tuple
from
...graph
import
Graph
,
Model
,
Node
from
.interface
import
AbstractOptimizer
from
.logical_plan
import
(
AbstractLogicalNode
,
LogicalGraph
,
LogicalPlan
,
OriginNode
,
PhysicalDevice
)
_supported_training_modules
=
[
'nni.retiarii.trainer.PyTorchImageClassificationTrainer'
]
class
DedupInputNode
(
AbstractLogicalNode
):
def
__init__
(
self
,
logical_graph
:
LogicalGraph
,
id
:
int
,
\
nodes_to_dedup
:
List
[
Node
],
_internal
=
False
):
super
().
__init__
(
logical_graph
,
id
,
\
"Dedup_"
+
nodes_to_dedup
[
0
].
name
,
\
def
__init__
(
self
,
logical_graph
:
LogicalGraph
,
node_
id
:
int
,
nodes_to_dedup
:
List
[
Node
],
_internal
=
False
):
super
().
__init__
(
logical_graph
,
node_
id
,
"Dedup_"
+
nodes_to_dedup
[
0
].
name
,
nodes_to_dedup
[
0
].
operation
)
self
.
origin_nodes
:
List
[
OriginNode
]
=
nodes_to_dedup
.
copy
()
self
.
origin_nodes
:
List
[
OriginNode
]
=
nodes_to_dedup
.
copy
()
def
assemble
(
self
,
multi_model_placement
:
Dict
[
Model
,
PhysicalDevice
])
->
Tuple
[
Node
,
PhysicalDevice
]:
for
node
in
self
.
origin_nodes
:
if
node
.
original_graph
.
model
in
multi_model_placement
:
new_node
=
Node
(
node
.
original_graph
,
node
.
id
,
\
f
'M_
{
node
.
original_graph
.
model
.
model_id
}
_
{
node
.
name
}
'
,
\
new_node
=
Node
(
node
.
original_graph
,
node
.
id
,
f
'M_
{
node
.
original_graph
.
model
.
model_id
}
_
{
node
.
name
}
'
,
node
.
operation
)
return
new_node
,
multi_model_placement
[
node
.
original_graph
.
model
]
raise
ValueError
(
f
'DedupInputNode
{
self
.
name
}
does not contain nodes from multi_model'
)
...
...
@@ -26,7 +28,6 @@ class DedupInputNode(AbstractLogicalNode):
def
_fork_to
(
self
,
graph
:
Graph
):
DedupInputNode
(
graph
,
self
.
id
,
self
.
origin_nodes
).
_register
()
def
__repr__
(
self
)
->
str
:
return
f
'DedupNode(id=
{
self
.
id
}
, name=
{
self
.
name
}
,
\
len(nodes_to_dedup)=
{
len
(
self
.
origin_nodes
)
}
'
...
...
@@ -35,6 +36,7 @@ class DedupInputNode(AbstractLogicalNode):
class
DedupInputOptimizer
(
AbstractOptimizer
):
def
__init__
(
self
)
->
None
:
pass
def
_check_deduplicate_by_node
(
self
,
root_node
,
node_to_check
):
if
root_node
==
node_to_check
:
return
True
...
...
@@ -51,12 +53,11 @@ class DedupInputOptimizer(AbstractOptimizer):
else
:
return
False
def
convert
(
self
,
logical_plan
:
LogicalPlan
)
->
None
:
nodes_to_skip
=
set
()
while
True
:
# repeat until the logical_graph converges
input_nodes
=
logical_plan
.
logical_graph
.
get_nodes_by_type
(
"_inputs"
)
#_PseudoOperation(type_name="_inputs"))
#
_PseudoOperation(type_name="_inputs"))
root_node
=
None
for
node
in
input_nodes
:
if
node
in
nodes_to_skip
:
...
...
@@ -77,7 +78,7 @@ class DedupInputOptimizer(AbstractOptimizer):
assert
(
nodes_to_dedup
[
0
]
==
root_node
)
nodes_to_skip
.
add
(
root_node
)
else
:
dedup_node
=
DedupInputNode
(
logical_plan
.
logical_graph
,
\
dedup_node
=
DedupInputNode
(
logical_plan
.
logical_graph
,
logical_plan
.
lp_model
.
_uid
(),
nodes_to_dedup
).
_register
()
for
edge
in
logical_plan
.
logical_graph
.
edges
:
if
edge
.
head
in
nodes_to_dedup
:
...
...
nni/retiarii/execution/logical_optimizer/opt_weight_sharing.py
deleted
100644 → 0
View file @
593a275c
from
.base_optimizer
import
BaseOptimizer
from
.logical_plan
import
LogicalPlan
class
WeightSharingOptimizer
(
BaseOptimizer
):
def
__init__
(
self
)
->
None
:
pass
def
convert
(
self
,
logical_plan
:
LogicalPlan
)
->
None
:
pass
nni/retiarii/experiment.py
View file @
59cd3982
import
dataclasses
import
logging
import
time
from
dataclasses
import
dataclass
from
pathlib
import
Path
from
subprocess
import
Popen
from
threading
import
Thread
from
typing
import
Any
,
List
,
Optional
from
typing
import
Any
,
Optional
from
..experiment
import
Experiment
,
TrainingServiceConfig
from
..experiment
import
launcher
,
rest
from
..experiment
import
Experiment
,
TrainingServiceConfig
,
launcher
,
rest
from
..experiment.config.base
import
ConfigBase
,
PathLike
from
..experiment.config
import
util
from
..experiment.pipe
import
Pipe
from
.graph
import
Model
from
.utils
import
get_records
from
.integration
import
RetiariiAdvisor
from
.converter.graph_gen
import
convert_to_graph
from
.mutator
import
LayerChoiceMutator
,
InputChoiceMutator
from
.converter
import
convert_to_graph
from
.mutator
import
Mutator
,
LayerChoiceMutator
,
InputChoiceMutator
from
.trainer.interface
import
BaseTrainer
from
.strategies.strategy
import
BaseStrategy
_logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
(
init
=
False
)
class
RetiariiExeConfig
(
ConfigBase
):
experiment_name
:
Optional
[
str
]
=
None
...
...
@@ -52,6 +56,7 @@ class RetiariiExeConfig(ConfigBase):
def
_validation_rules
(
self
):
return
_validation_rules
_canonical_rules
=
{
'trial_code_directory'
:
util
.
canonical_path
,
'max_experiment_duration'
:
lambda
value
:
f
'
{
util
.
parse_time
(
value
)
}
s'
if
value
is
not
None
else
None
,
...
...
@@ -70,8 +75,8 @@ _validation_rules = {
class
RetiariiExperiment
(
Experiment
):
def
__init__
(
self
,
base_model
:
'nn.Module'
,
trainer
:
'
BaseTrainer
'
,
applied_mutators
:
List
[
'
Mutator
'
]
,
strategy
:
'
BaseStrategy
'
):
def
__init__
(
self
,
base_model
:
Model
,
trainer
:
BaseTrainer
,
applied_mutators
:
Mutator
,
strategy
:
BaseStrategy
):
self
.
config
:
RetiariiExeConfig
=
None
self
.
port
:
Optional
[
int
]
=
None
...
...
nni/retiarii/graph.py
View file @
59cd3982
...
...
@@ -5,7 +5,6 @@ Model representation.
import
copy
from
enum
import
Enum
import
json
from
collections
import
defaultdict
from
typing
import
(
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
,
overload
)
from
.operation
import
Cell
,
Operation
,
_IOPseudoOperation
...
...
@@ -330,11 +329,11 @@ class Graph:
"""
return
[
node
for
node
in
self
.
hidden_nodes
if
node
.
operation
.
type
==
operation_type
]
def
get_node_by_id
(
self
,
id
:
int
)
->
Optional
[
'Node'
]:
def
get_node_by_id
(
self
,
node_
id
:
int
)
->
Optional
[
'Node'
]:
"""
Returns the node which has specified name; or returns `None` if no node has this name.
"""
found
=
[
node
for
node
in
self
.
nodes
if
node
.
id
==
id
]
found
=
[
node
for
node
in
self
.
nodes
if
node
.
id
==
node_
id
]
return
found
[
0
]
if
found
else
None
def
get_nodes_by_label
(
self
,
label
:
str
)
->
List
[
'Node'
]:
...
...
@@ -365,7 +364,8 @@ class Graph:
curr_nodes
.
append
(
successor
)
for
key
in
node_to_fanin
:
assert
node_to_fanin
[
key
]
==
0
,
'{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'
.
format
(
key
,
assert
node_to_fanin
[
key
]
==
0
,
'{}, fanin: {}, predecessor: {}, edges: {}, fanin: {}, keys: {}'
.
format
(
key
,
node_to_fanin
[
key
],
key
.
predecessors
[
0
],
self
.
edges
,
...
...
@@ -587,6 +587,7 @@ class Node:
ret
[
'label'
]
=
self
.
label
return
ret
class
Edge
:
"""
A tensor, or "data flow", between two nodes.
...
...
nni/retiarii/integration.py
View file @
59cd3982
import
logging
import
threading
from
typing
import
*
from
typing
import
Any
,
Callable
import
json_tricks
import
nni
from
nni.runtime.msg_dispatcher_base
import
MsgDispatcherBase
from
nni.runtime.protocol
import
send
,
CommandType
from
nni.runtime.protocol
import
CommandType
,
send
from
nni.utils
import
MetricType
from
.
import
utils
from
.graph
import
MetricData
_logger
=
logging
.
getLogger
(
'nni.msg_dispatcher_base'
)
...
...
@@ -44,6 +41,7 @@ class RetiariiAdvisor(MsgDispatcherBase):
final_metric_callback
"""
def
__init__
(
self
):
super
(
RetiariiAdvisor
,
self
).
__init__
()
register_advisor
(
self
)
# register the current advisor as the "global only" advisor
...
...
@@ -88,28 +86,28 @@ class RetiariiAdvisor(MsgDispatcherBase):
'parameters'
:
parameters
,
'parameter_source'
:
'algorithm'
}
_logger
.
info
(
'New trial sent:
{}'
.
format
(
new_trial
)
)
_logger
.
info
(
'New trial sent:
%s'
,
new_trial
)
send
(
CommandType
.
NewTrialJob
,
json_tricks
.
dumps
(
new_trial
))
if
self
.
send_trial_callback
is
not
None
:
self
.
send_trial_callback
(
parameters
)
# pylint: disable=not-callable
return
self
.
parameters_count
def
handle_request_trial_jobs
(
self
,
num_trials
):
_logger
.
info
(
'Request trial jobs:
{}'
.
format
(
num_trials
)
)
_logger
.
info
(
'Request trial jobs:
%s'
,
num_trials
)
if
self
.
request_trial_jobs_callback
is
not
None
:
self
.
request_trial_jobs_callback
(
num_trials
)
# pylint: disable=not-callable
def
handle_update_search_space
(
self
,
data
):
_logger
.
info
(
'Received search space:
{}'
.
format
(
data
)
)
_logger
.
info
(
'Received search space:
%s'
,
data
)
self
.
search_space
=
data
def
handle_trial_end
(
self
,
data
):
_logger
.
info
(
'Trial end:
{}'
.
format
(
data
))
# do nothing
_logger
.
info
(
'Trial end:
%s'
,
data
)
self
.
trial_end_callback
(
json_tricks
.
loads
(
data
[
'hyper_params'
])[
'parameter_id'
],
# pylint: disable=not-callable
data
[
'event'
]
==
'SUCCEEDED'
)
def
handle_report_metric_data
(
self
,
data
):
_logger
.
info
(
'Metric reported:
{}'
.
format
(
data
)
)
_logger
.
info
(
'Metric reported:
%s'
,
data
)
if
data
[
'type'
]
==
MetricType
.
REQUEST_PARAMETER
:
raise
ValueError
(
'Request parameter not supported'
)
elif
data
[
'type'
]
==
MetricType
.
PERIODICAL
:
...
...
nni/retiarii/mutator.py
View file @
59cd3982
...
...
@@ -13,6 +13,7 @@ class Sampler:
"""
Handles `Mutator.choice()` calls.
"""
def
choice
(
self
,
candidates
:
List
[
Choice
],
mutator
:
'Mutator'
,
model
:
Model
,
index
:
int
)
->
Choice
:
raise
NotImplementedError
()
...
...
@@ -35,6 +36,7 @@ class Mutator:
For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
# Method names are open for discussion.
"""
def
__init__
(
self
,
sampler
:
Optional
[
Sampler
]
=
None
):
self
.
sampler
:
Optional
[
Sampler
]
=
sampler
self
.
_cur_model
:
Optional
[
Model
]
=
None
...
...
@@ -77,7 +79,6 @@ class Mutator:
self
.
sampler
=
sampler_backup
return
recorder
.
recorded_candidates
,
new_model
def
mutate
(
self
,
model
:
Model
)
->
None
:
"""
Abstract method to be implemented by subclass.
...
...
@@ -105,6 +106,7 @@ class _RecorderSampler(Sampler):
# the following is for inline mutation
class
LayerChoiceMutator
(
Mutator
):
def
__init__
(
self
,
node_name
:
str
,
candidates
:
List
):
super
().
__init__
()
...
...
@@ -118,6 +120,7 @@ class LayerChoiceMutator(Mutator):
chosen_cand
=
self
.
candidates
[
chosen_index
]
target
.
update_operation
(
chosen_cand
[
'type'
],
chosen_cand
[
'parameters'
])
class
InputChoiceMutator
(
Mutator
):
def
__init__
(
self
,
node_name
:
str
,
n_chosen
:
int
):
super
().
__init__
()
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment