OpenDAS / nni · Commit 7d1acfbd (unverified)
Authored Dec 11, 2020 by Zhenhua Han; committed by GitHub, Dec 11, 2020.
[Retiarii] cross-graph optimization: input deduplication (#3105)
Parent: 165756cc
Showing 20 of 30 changed files (page 1 of 2), with 767 additions and 30 deletions.
.gitignore                                                      +3   −0
nni/retiarii/__init__.py                                        +2   −1
nni/retiarii/codegen/pytorch.py                                 +11  −5
nni/retiarii/execution/api.py                                   +6   −1
nni/retiarii/execution/cgo_engine.py                            +151 −0
nni/retiarii/execution/logical_optimizer/__init__.py            +0   −0
nni/retiarii/execution/logical_optimizer/interface.py           +11  −0
nni/retiarii/execution/logical_optimizer/logical_plan.py        +297 −0
nni/retiarii/execution/logical_optimizer/opt_batching.py        +10  −0
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py     +88  −0
nni/retiarii/execution/logical_optimizer/opt_weight_sharing.py  +10  −0
nni/retiarii/graph.py                                           +15  −4
nni/retiarii/integration.py                                     +4   −1
nni/retiarii/mutator.py                                         +1   −7
nni/retiarii/operation.py                                       +5   −1
nni/retiarii/operation_def/tf_op_def.py                         +1   −1
nni/retiarii/operation_def/torch_op_def.py                      +16  −0
nni/retiarii/trainer/__init__.py                                +1   −1
nni/retiarii/trainer/pytorch/__init__.py                        +1   −1
nni/retiarii/trainer/pytorch/base.py                            +134 −7
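Taken together, the changes add a cross-graph optimization (CGO) path to Retiarii: submitted trial models are merged into one logical plan, identical dataset-input nodes are deduplicated, each model is pinned to its own GPU, and a single fused multi-model is code-generated and trained in lockstep. The call trace below is an orientation sketch assembled from the diffs that follow; it is a reading aid, not runnable output:

    # submit_models(*models)                                    cgo_engine.py
    #   -> LogicalPlan.add_model(model) for each model          merge root graphs
    #   -> DedupInputOptimizer.convert(logical_plan)            fuse identical '_inputs' nodes
    #   -> AssemblePolicy.group(logical_plan)                   model i -> cuda:i on one server
    #   -> LogicalPlan.assemble(group)                          physical multi-model + placements
    #   -> codegen.model_to_pytorch_script(model, placement=p)  emit .to('cuda:i') per node
    #   -> send_trial(...)                                      trial runs PyTorchMultiModelTrainer.fit()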
.gitignore  (+3 −0)

@@ -97,3 +97,6 @@ venv.bak/
 # VSCode
 .vscode
 .vs
+.history
+generated/
+test/ut/retiarii/_debug_graph_data.json
nni/retiarii/__init__.py  (+2 −1)

+from .execution import *
+from .operation import Operation
 from .graph import *
-from .execution import *
 from .mutator import *
 from .model_apis import nn
nni/retiarii/codegen/pytorch.py  (+11 −5)

@@ -7,11 +7,12 @@ from ..operation import Operation, Cell
 _logger = logging.getLogger(__name__)


-def model_to_pytorch_script(model: Model) -> str:
+def model_to_pytorch_script(model: Model, placement=None) -> str:
     graphs = []
     total_pkgs = set()
     for name, cell in model.graphs.items():
-        import_pkgs, graph_code = graph_to_pytorch_model(name, cell)
+        import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement=placement)
         graphs.append(graph_code)
         total_pkgs.update(import_pkgs)
     # TODO: set correct PATH for the packages (after launch refactor)

@@ -23,6 +24,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
     _logger.info('sorted_incoming_edges: {}'.format(edges))
     if not edges:
         return []
+    _logger.info(f'all tail_slots are None: {[edge.tail_slot for edge in edges]}')
     if all(edge.tail_slot is None for edge in edges):
         return edges
     if all(isinstance(edge.tail_slot, int) for edge in edges):

@@ -52,13 +54,14 @@ def _format_inputs(node: Node) -> List[str]:
             inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
     return inputs


-def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
-    nodes = graph.nodes
-    # FIXME: topological sort is needed here
+def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str:
+    nodes = graph.topo_sort()
     # handle module node and function node differently
     # only need to generate code for module here
     import_pkgs = set()
     node_codes = []
+    placement_codes = []
     for node in nodes:
         if node.operation:
             pkg_name = node.operation.get_import_pkg()

@@ -66,6 +69,9 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
                 import_pkgs.add(pkg_name)
             node_code = node.operation.to_init_code(node.name)
             if node_code is not None:
-                node_codes.append(node_code)
+                if placement and node in placement and len(node_code) > 0:
+                    node_codes.append(f"{node_code}.to('{placement[node].device}')")
+                else:
+                    node_codes.append(node_code)

     if graph.input_node.operation.io_names is None:
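The effect of the new placement parameter is easiest to see on a single generated line: when a node has a placement, its init code gains a .to(device) suffix. A minimal sketch, with a made-up node_code string standing in for to_init_code() output:

    # hypothetical values for illustration; placement[node].device in the real code
    node_code = 'self.conv1 = nn.Conv2d(3, 16, kernel_size=3)'
    device = 'cuda:1'

    print(f"{node_code}.to('{device}')")
    # self.conv1 = nn.Conv2d(3, 16, kernel_size=3).to('cuda:1')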
nni/retiarii/execution/api.py  (+6 −1)

 import time
+import os
 import importlib.util
 from typing import *

 from ..graph import Model, ModelStatus
 from .base import BaseExecutionEngine
+from .cgo_engine import CGOExecutionEngine
 from .interface import *
 from .listener import DefaultListener

@@ -21,6 +23,9 @@ def get_execution_engine() -> BaseExecutionEngine:
     """
     global _execution_engine
     if _execution_engine is None:
-        _execution_engine = BaseExecutionEngine()
+        if os.environ.get('CGO') == 'true':
+            _execution_engine = CGOExecutionEngine()
+        else:
+            _execution_engine = BaseExecutionEngine()
     return _execution_engine
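Engine selection is driven entirely by the CGO environment variable, checked lazily the first time the singleton is created. A sketch of opting in, assuming a fully configured NNI experiment (instantiating CGOExecutionEngine also requires a live advisor):

    import os

    # must be set before the first get_execution_engine() call, because the
    # chosen engine is cached in a module-level singleton
    os.environ['CGO'] = 'true'   # any other value falls back to BaseExecutionEngine

    from nni.retiarii.execution.api import get_execution_engine

    engine = get_execution_engine()   # CGOExecutionEngine when CGO == 'true'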
nni/retiarii/execution/cgo_engine.py  (new file, +151)

import logging
import json

from typing import *

from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData
from ..integration import send_trial, receive_trial_parameters, get_advisor
from .logical_optimizer.logical_plan import LogicalPlan, PhysicalDevice
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
from .base import BaseGraphData

_logger = logging.getLogger(__name__)


class CGOExecutionEngine(AbstractExecutionEngine):
    def __init__(self, n_model_per_graph=4) -> None:
        self._listeners: List[AbstractGraphListener] = []
        self._running_models: Dict[int, Model] = dict()
        self.logical_plan_counter = 0
        self.n_model_per_graph = n_model_per_graph
        self._optimizers = [DedupInputOptimizer()]
        self._original_models = {}
        self._original_model_to_multi_model = {}

        # register advisor callbacks
        advisor = get_advisor()
        advisor.send_trial_callback = self._send_trial_callback
        advisor.request_trial_jobs_callback = self._request_trial_jobs_callback
        advisor.trial_end_callback = self._trial_end_callback
        advisor.intermediate_metric_callback = self._intermediate_metric_callback
        advisor.final_metric_callback = self._final_metric_callback

    def add_optimizer(self, opt):
        self._optimizers.append(opt)

    def submit_models(self, *models: List[Model]) -> None:
        _logger.info(f'{len(models)} models are submitted')
        logical = self._build_logical(models)
        for opt in self._optimizers:
            opt.convert(logical)
        phy_models_and_placements = self._assemble(logical)
        for model, placement, grouped_models in phy_models_and_placements:
            data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement),
                                 model.training_config.module, model.training_config.kwargs)
            for m in grouped_models:
                self._original_models[m.model_id] = m
                self._original_model_to_multi_model[m.model_id] = model
            self._running_models[send_trial(data.dump())] = model
        # for model in models:
        #     data = BaseGraphData(codegen.model_to_pytorch_script(model),
        #                          model.config['trainer_module'], model.config['trainer_kwargs'])
        #     self._running_models[send_trial(data.dump())] = model

    def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
        # unique_models = set()
        # for node in logical_plan.graph.nodes:
        #     if node.graph.model not in unique_models:
        #         unique_models.add(node.graph.model)
        # return [m for m in unique_models]
        grouped_models: List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
        phy_models_and_placements = []
        for multi_model in grouped_models:
            model, model_placement = logical_plan.assemble(multi_model)
            phy_models_and_placements.append((model, model_placement, multi_model.keys()))
        return phy_models_and_placements

    def _build_logical(self, models: List[Model]) -> LogicalPlan:
        logical_plan = LogicalPlan(id=self.logical_plan_counter)
        for model in models:
            logical_plan.add_model(model)
        self.logical_plan_counter += 1
        return logical_plan

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        self._listeners.append(listener)

    def _send_trial_callback(self, parameter: dict) -> None:
        for listener in self._listeners:
            listener.on_resource_used(0)  # FIXME: find the real resource id

    def _request_trial_jobs_callback(self, num_trials: int) -> None:
        for listener in self._listeners:
            listener.on_resource_available([0] * num_trials)  # FIXME: find the real resource id

    def _trial_end_callback(self, trial_id: int, success: bool) -> None:
        model = self._running_models[trial_id]
        if success:
            model.status = ModelStatus.Trained
        else:
            model.status = ModelStatus.Failed
        for model_id in self._original_model_to_multi_model:
            if self._original_model_to_multi_model[model_id] == model:
                original_model = self._original_models[model_id]
                if success:
                    original_model.status = ModelStatus.Trained
                else:
                    original_model.status = ModelStatus.Failed
                for listener in self._listeners:
                    listener.on_training_end(original_model, success)

    def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        # model = self._running_models[trial_id]
        merged_metrics = dict(metrics)
        for model_id in merged_metrics:
            int_model_id = int(model_id)
            self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
            # model.intermediate_metrics.append(metrics)
            for listener in self._listeners:
                listener.on_intermediate_metric(self._original_models[int_model_id], merged_metrics[model_id])

    def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        merged_metrics = dict(metrics)
        for model_id in merged_metrics:
            int_model_id = int(model_id)
            self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
            # model.intermediate_metrics.append(metrics)
            for listener in self._listeners:
                listener.on_metric(self._original_models[int_model_id], merged_metrics[model_id])

    def query_available_resource(self) -> List[WorkerInfo]:
        raise NotImplementedError  # move the method from listener to here?

    @classmethod
    def trial_execute_graph(cls) -> None:
        """
        Initialize the model, hand it over to trainer.
        """
        graph_data = BaseGraphData.load(receive_trial_parameters())
        _logger.info('CGO_ENGINE trial parameters received')
        with open('_generated_model.py', 'w') as f:
            f.write(graph_data.model_script)
        # with open('_debug_graph_data.json', 'w') as f:
        #     json.dump(graph_data.dump(), f)
        trainer_cls = utils.import_(graph_data.training_module)
        model_cls = utils.import_(f"_generated_model.{graph_data.training_kwargs['model_cls']}")
        trainer_instance = trainer_cls(model_cls(), graph_data.training_kwargs)
        trainer_instance.fit()


class AssemblePolicy:
    @staticmethod
    def group(logical_plan):
        group_model = {}
        for idx, m in enumerate(logical_plan.models):
            group_model[m] = PhysicalDevice('server', f'cuda:{idx}')
        return [group_model]
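AssemblePolicy.group is the simplest possible policy: one group containing every model, with model i pinned to cuda:i on a single server. A standalone re-implementation of just that mapping, trimmed so it runs without NNI (the PhysicalDevice fields mirror the class in logical_plan.py below):

    from typing import Dict, List


    class PhysicalDevice:
        def __init__(self, server: str, device: str):
            self.server, self.device = server, device

        def __repr__(self):
            return f'{self.server}/{self.device}'


    def group(models: List[str]) -> List[Dict[str, PhysicalDevice]]:
        # one group containing every model, model i -> cuda:i
        return [{m: PhysicalDevice('server', f'cuda:{idx}') for idx, m in enumerate(models)}]


    print(group(['model_17', 'model_18']))
    # [{'model_17': server/cuda:0, 'model_18': server/cuda:1}]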
nni/retiarii/execution/logical_optimizer/__init__.py  (new empty file)
nni/retiarii/execution/logical_optimizer/interface.py  (new file, +11)

from abc import *
from typing import *

from .logical_plan import LogicalPlan


class AbstractOptimizer(ABC):
    def __init__(self) -> None:
        pass

    def convert(self, logical_plan: LogicalPlan) -> None:
        raise NotImplementedError
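Concrete passes subclass AbstractOptimizer and mutate the plan in place; DedupInputOptimizer below is the only substantive implementation in this commit. For shape, a hypothetical no-op pass against this interface would be:

    class NoopOptimizer(AbstractOptimizer):
        """Illustrative only: leaves the logical plan untouched."""

        def convert(self, logical_plan: LogicalPlan) -> None:
            pass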
nni/retiarii/execution/logical_optimizer/logical_plan.py  (new file, +297)

import copy
import logging
from typing import *

from nni.retiarii.graph import Model, Graph, Edge, Node, Cell
from nni.retiarii.operation import Operation, _IOPseudoOperation


class PhysicalDevice:
    def __init__(self, server: str, device: str):
        self.server = server
        self.device = device

    def __eq__(self, o) -> bool:
        return self.server == o.server and self.device == o.device

    def __hash__(self) -> int:
        return hash(self.server + '_' + self.device)


class AbstractLogicalNode(Node):
    def __init__(self, graph, node_id, name, operation, _internal=False):
        super().__init__(graph, node_id, name, operation, _internal=_internal)

    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
        raise NotImplementedError

    def _fork_to(self, graph: Graph):
        raise NotImplementedError


class LogicalGraph(Graph):
    def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
        super().__init__(model, graph_id, name='logical_' + name, _internal=_internal)

    def _dump(self) -> Any:
        nodes_dump = {}
        for node in self.hidden_nodes:
            if isinstance(node, OriginNode):
                nodes_dump[f"{node.original_graph.model.model_id}_{node.name}"] = node._dump()
            else:
                nodes_dump[f"{node.graph.model.model_id}_{node.name}"] = node._dump()

        edges_dump = []
        for edge in self.edges:
            if isinstance(edge.head, OriginNode):
                head_info = f'{edge.head.original_graph.model.model_id}_{edge.head.name}'
            else:
                head_info = edge.head.name
            if isinstance(edge.tail, OriginNode):
                tail_info = f'{edge.tail.original_graph.model.model_id}_{edge.tail.name}'
            else:
                tail_info = edge.tail.name
            edges_dump.append((head_info, tail_info))
        return {
            'inputs': self.input_node.operation.io_names,
            'outputs': self.output_node.operation.io_names,
            'nodes': nodes_dump,
            'edges': edges_dump
        }

    def _fork_to(self, model: Model) -> Graph:
        new_graph = Graph(model, self.id, self.name, _internal=True)._register()

        for node in self.hidden_nodes:
            if isinstance(node, AbstractLogicalNode):
                node._fork_to(new_graph)
            else:
                Node(new_graph, node.id, node.name, node.operation, _internal=True)._register()

        id_to_new_node = {node.__repr__(): node for node in new_graph.nodes}

        for edge in self.edges:
            new_head = id_to_new_node[edge.head.__repr__()]
            new_tail = id_to_new_node[edge.tail.__repr__()]
            Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

        return new_graph


class OriginNode(AbstractLogicalNode):
    def __init__(self, logical_graph: LogicalGraph, original_graph: Graph, original_node: Node,
                 name: str, operation, _internal=False):
        super().__init__(logical_graph, original_node.id, name, operation)
        self.original_graph = original_graph
        self.original_node = original_node

    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
        model_id = self.original_node.graph.model.model_id
        new_node = Node(self.original_node.graph, self.original_node.id,
                        f"M_{model_id}_" + self.original_node.name,
                        self.original_node.operation)
        return new_node, multi_model_placement[self.original_node.graph.model]

    def __repr__(self):
        return f'OriginNode(id={self.id}, name={self.name}, ' \
               f'operation={self.operation}, origin_model_id={self.original_graph.model.model_id})'

    def _fork_to(self, graph: Graph):
        OriginNode(graph, self.original_graph, self.original_node,
                   self.name, self.operation)._register()


class LogicalPlan:
    def __init__(self, id=0) -> None:
        self.lp_model = Model(_internal=True)
        self.id = id
        self.logical_graph = LogicalGraph(self.lp_model, id, name=f'{id}', _internal=True)._register()
        self.lp_model._root_graph_name = self.logical_graph.name
        self.models = []

    def add_model(self, model: Model):
        self.models.append(model)
        # Only optimize the root graph.
        self._merge_graph(model.root_graph)

    def _merge_graph(self, from_graph):
        to_graph = self.logical_graph
        id_to_new_node = {}  # old node ID -> new node object
        for old_node in from_graph.nodes:
            new_node = OriginNode(to_graph, old_node.graph,
                                  old_node, old_node.name,
                                  old_node.operation, _internal=True)._register()
            id_to_new_node[old_node.id] = new_node
        for edge in from_graph.edges:
            new_head = id_to_new_node[edge.head.id]
            new_tail = id_to_new_node[edge.tail.id]
            Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \
            -> Tuple[Model, Dict[Node, PhysicalDevice], List[Model]]:
        phy_model = Model(_internal=True)  # self.lp_model.fork()
        phy_graph = self.lp_model.root_graph._fork_to(phy_model)

        # Add a flag to mark multi-model in graph json.
        # Multi-model has a list of training configs in kwargs['model_kwargs'].
        if len(multi_model_placement) > 1:
            phy_model.training_config.kwargs['is_multi_model'] = True
            phy_model.training_config.kwargs['model_cls'] = phy_graph.name
            phy_model.training_config.kwargs['model_kwargs'] = []
            # FIXME: allow user to specify
            phy_model.training_config.module = 'nni.retiarii.trainer.PyTorchMultiModelTrainer'

        # merge sub-graphs
        for model in multi_model_placement:
            for graph_name in model.graphs:
                if graph_name != model._root_graph_name:
                    model.graphs[graph_name]._fork_to(phy_model, name_prefix=f'M_{model.model_id}_')

        # When replacing logical nodes, merge the training configs when
        # input/output nodes are replaced.
        training_config_slot = {}  # Model ID -> Slot ID
        input_slot_mapping = {}
        output_slot_mapping = {}
        # Replace all logical nodes with executable physical nodes
        hidden_nodes = phy_graph.hidden_nodes.copy()
        node_placements = {}
        for node in hidden_nodes:
            if isinstance(node, OriginNode):
                model_id = node.original_graph.model.model_id
                if node.original_graph.model not in multi_model_placement:
                    for edge in node.incoming_edges:
                        edge.remove()
                    for edge in node.outgoing_edges:
                        edge.remove()
                    node.remove()
                    continue
            if isinstance(node, AbstractLogicalNode):
                new_node, placement = node.assemble(multi_model_placement)
                if isinstance(new_node.operation, _IOPseudoOperation):
                    model_id = new_node.graph.model.model_id
                    if model_id not in training_config_slot:
                        phy_model.training_config.kwargs['model_kwargs'].append(
                            new_node.graph.model.training_config.kwargs.copy())
                        training_config_slot[model_id] = \
                            len(phy_model.training_config.kwargs['model_kwargs']) - 1
                        slot = training_config_slot[model_id]
                        phy_model.training_config.kwargs['model_kwargs'][slot]['model_id'] = model_id
                        phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = False
                        phy_model.training_config.kwargs['model_kwargs'][slot]['use_output'] = False
                    else:
                        slot = training_config_slot[model_id]
                    # If a model's inputs/outputs are not used in the multi-model,
                    # the codegen and trainer should not generate and use them.
                    # "use_input" and "use_output" mark whether an input/output
                    # of a model is used in a multi-model.
                    if new_node.operation.type == '_inputs':
                        input_slot_mapping[new_node] = slot
                        phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = True
                    if new_node.operation.type == '_outputs':
                        output_slot_mapping[new_node] = slot
                        phy_model.training_config.kwargs['model_kwargs'][slot]['use_output'] = True

                self.node_replace(node, new_node)

                if isinstance(new_node.operation, Cell):
                    old_cell_name = new_node.operation.cell_name
                    new_node.operation = copy.deepcopy(new_node.operation)
                    new_node.operation.cell_name = f'M_{model_id}_{old_cell_name}'

                node_placements[new_node] = placement
                node.remove()

        # If two nodes are placed on different devices, use a ToDevice op to copy the tensor
        existing_edges = phy_graph.edges.copy()
        # Avoid copying a node multiple times onto the same device
        copied_op: Dict[Tuple[Node, PhysicalDevice], Node] = {}
        for edge in existing_edges:
            head_placement = node_placements[edge.head]
            tail_placement = node_placements[edge.tail]
            if head_placement != tail_placement:
                if head_placement.server != tail_placement.server:
                    raise ValueError('Cross-server placement is not supported.')
                # Same server, different devices
                if (edge.head, tail_placement) in copied_op:
                    to_node = copied_op[(edge.head, tail_placement)]
                else:
                    to_operation = Operation.new('ToDevice', {"device": tail_placement.device})
                    to_node = Node(phy_graph, phy_model._uid(),
                                   edge.head.name + "_to_" + edge.tail.name,
                                   to_operation)._register()
                    Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
                    copied_op[(edge.head, tail_placement)] = to_node
                edge.head = to_node
                edge.head_slot = None

        # merge all input nodes into one with multiple slots
        input_nodes = []
        for node in phy_graph.hidden_nodes:
            if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_inputs':
                input_nodes.append(node)
        for edge in phy_graph.edges:
            if edge.head in input_nodes:
                edge.head_slot = input_slot_mapping[edge.head]
                edge.head = phy_graph.input_node

        # merge all output nodes into one with multiple slots
        output_nodes = []
        for node in phy_graph.hidden_nodes:
            if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_outputs':
                output_nodes.append(node)
        for edge in phy_graph.edges:
            if edge.tail in output_nodes:
                edge.tail_slot = output_slot_mapping[edge.tail]
                edge.tail = phy_graph.output_node

        for node in input_nodes:
            node.remove()
        for node in output_nodes:
            node.remove()

        return phy_model, node_placements

    def node_replace(self, old_node: Node, new_node: Node,
                     input_slot_mapping=None, output_slot_mapping=None):
        # TODO: currently, only a single input slot and output slot are supported.
        if input_slot_mapping is not None or output_slot_mapping is not None:
            raise ValueError('Slot mapping is not supported')

        phy_graph = old_node.graph
        new_node.graph = phy_graph
        new_node._register()

        for edge in phy_graph.edges:
            if edge.head == old_node:
                edge.head = new_node
            elif edge.tail == old_node:
                edge.tail = new_node

        # After the replacement there might be multiple duplicated edges
        # with the same input and output nodes, which should be de-duplicated.
        self._remove_duplicated_edges()

    def _remove_duplicated_edges(self):
        # TODO: there are no duplicated edges if only dedup-input is supported.
        # Duplicated edges appear when a chain of prefix nodes is deduplicated.
        pass
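PhysicalDevice compares and hashes on (server, device), which is what lets assemble use (head node, target device) pairs as dictionary keys in copied_op to avoid inserting duplicate ToDevice nodes. A self-contained check of that semantics (class body copied from above, node names hypothetical):

    class PhysicalDevice:
        def __init__(self, server: str, device: str):
            self.server, self.device = server, device

        def __eq__(self, o) -> bool:
            return self.server == o.server and self.device == o.device

        def __hash__(self) -> int:
            return hash(self.server + '_' + self.device)


    # two handles to the same GPU behave as one dict key
    copied_op = {('conv1', PhysicalDevice('server', 'cuda:1')): 'conv1_to_fc'}
    assert ('conv1', PhysicalDevice('server', 'cuda:1')) in copied_op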
nni/retiarii/execution/logical_optimizer/opt_batching.py  (new file, +10)

from .base_optimizer import BaseOptimizer
from .logical_plan import LogicalPlan


class BatchingOptimizer(BaseOptimizer):
    def __init__(self) -> None:
        pass

    def convert(self, logical_plan: LogicalPlan) -> None:
        pass
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py  (new file, +88)

from .interface import AbstractOptimizer
from .logical_plan import LogicalPlan, AbstractLogicalNode, LogicalGraph, OriginNode, PhysicalDevice
from nni.retiarii import Graph, Node, Model
from typing import *
from nni.retiarii.operation import _IOPseudoOperation

_supported_training_modules = ['nni.retiarii.trainer.PyTorchImageClassificationTrainer']


class DedupInputNode(AbstractLogicalNode):
    def __init__(self, logical_graph: LogicalGraph, id: int,
                 nodes_to_dedup: List[Node], _internal=False):
        super().__init__(logical_graph, id,
                         "Dedup_" + nodes_to_dedup[0].name,
                         nodes_to_dedup[0].operation)
        self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()

    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
        for node in self.origin_nodes:
            if node.original_graph.model in multi_model_placement:
                new_node = Node(node.original_graph, node.id,
                                f'M_{node.original_graph.model.model_id}_{node.name}',
                                node.operation)
                return new_node, multi_model_placement[node.original_graph.model]
        raise ValueError(f'DedupInputNode {self.name} does not contain nodes from multi_model')

    def _fork_to(self, graph: Graph):
        DedupInputNode(graph, self.id, self.origin_nodes)._register()

    def __repr__(self) -> str:
        return f'DedupNode(id={self.id}, name={self.name}, ' \
               f'len(nodes_to_dedup)={len(self.origin_nodes)})'


class DedupInputOptimizer(AbstractOptimizer):
    def __init__(self) -> None:
        pass

    def _check_deduplicate_by_node(self, root_node, node_to_check):
        if root_node == node_to_check:
            return True
        if root_node.operation.type == '_inputs' and \
                node_to_check.operation.type == '_inputs' and \
                isinstance(root_node, OriginNode) and \
                isinstance(node_to_check, OriginNode):
            if root_node.original_graph.model.training_config.module not in _supported_training_modules:
                return False
            if root_node.original_graph.model.training_config == \
                    node_to_check.original_graph.model.training_config:
                return True
            else:
                return False
        else:
            return False

    def convert(self, logical_plan: LogicalPlan) -> None:
        nodes_to_skip = set()
        while True:  # repeat until the logical_graph converges
            input_nodes = logical_plan.logical_graph.get_nodes_by_type("_inputs")
            root_node = None
            for node in input_nodes:
                if node in nodes_to_skip:
                    continue
                root_node = node
                break
            if root_node is None:
                break  # end of convert
            else:
                nodes_to_dedup = []
                for node in input_nodes:
                    if node in nodes_to_skip:
                        continue
                    if self._check_deduplicate_by_node(root_node, node):
                        nodes_to_dedup.append(node)
                assert len(nodes_to_dedup) >= 1
                if len(nodes_to_dedup) == 1:
                    assert nodes_to_dedup[0] == root_node
                    nodes_to_skip.add(root_node)
                else:
                    dedup_node = DedupInputNode(logical_plan.logical_graph,
                                                logical_plan.lp_model._uid(),
                                                nodes_to_dedup)._register()
                    for edge in logical_plan.logical_graph.edges:
                        if edge.head in nodes_to_dedup:
                            edge.head = dedup_node
                        if edge.tail in nodes_to_dedup:
                            edge.tail = dedup_node
                    for node in nodes_to_dedup:
                        node.remove()
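The convergence loop above repeatedly picks an unprocessed _inputs node as the root, collects all compatible input nodes, and fuses them into one DedupInputNode. A toy rerun of the same loop shape, with strings standing in for nodes and dataset names standing in for training configs:

    def dedup(inputs, same):
        skip, groups = set(), []
        while True:
            # pick the next input node that has not been processed yet
            root = next((n for n in inputs if n not in skip), None)
            if root is None:
                break
            group = [n for n in inputs if n not in skip and same(root, n)]
            if len(group) == 1:
                skip.add(root)            # nothing to merge with
            else:
                groups.append(group)      # fuse into one dedup node
                inputs = [n for n in inputs if n not in group]
        return groups


    # models 0 and 2 share a dataset config; model 1 differs
    configs = {'in0': 'mnist', 'in1': 'cifar', 'in2': 'mnist'}
    print(dedup(list(configs), lambda a, b: configs[a] == configs[b]))
    # [['in0', 'in2']]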
nni/retiarii/execution/logical_optimizer/opt_weight_sharing.py  (new file, +10)

from .base_optimizer import BaseOptimizer
from .logical_plan import LogicalPlan


class WeightSharingOptimizer(BaseOptimizer):
    def __init__(self) -> None:
        pass

    def convert(self, logical_plan: LogicalPlan) -> None:
        pass
nni/retiarii/graph.py  (+15 −4)

@@ -5,6 +5,7 @@ Model representation.
 import copy
 from enum import Enum
 import json
+from collections import defaultdict
 from typing import (Any, Dict, List, Optional, Tuple, Union, overload)

 from .operation import Cell, Operation, _IOPseudoOperation

@@ -51,6 +52,10 @@ class TrainingConfig:
             'kwargs': self.kwargs
         }

+    def __eq__(self, other):
+        return self.module == other.module and \
+            self.kwargs == other.kwargs
+

 class Model:
     """

@@ -311,6 +316,13 @@ class Graph:
         """
         return [node for node in self.hidden_nodes if node.operation.type == operation_type]

+    def get_node_by_id(self, id: int) -> Optional['Node']:
+        """
+        Returns the node with the specified id, or `None` if no node has this id.
+        """
+        found = [node for node in self.nodes if node.id == id]
+        return found[0] if found else None
+
     def get_nodes_by_label(self, label: str) -> List['Node']:
         return [node for node in self.hidden_nodes if node.label == label]

@@ -347,8 +359,8 @@ class Graph:
     def __eq__(self, other: object) -> bool:
         return self is other

-    def _fork_to(self, model: Model) -> 'Graph':
-        new_graph = Graph(model, self.id, self.name, _internal=True)._register()
+    def _fork_to(self, model: Model, name_prefix='') -> 'Graph':
+        new_graph = Graph(model, self.id, name_prefix + self.name, _internal=True)._register()
         # TODO: use node copy instead
         new_graph.input_node.operation.io_names = self.input_node.operation.io_names
         new_graph.output_node.operation.io_names = self.output_node.operation.io_names

@@ -544,7 +556,6 @@ class Node:
             ret['label'] = self.label
         return ret

-
 class Edge:
     """
     A tensor, or "data flow", between two nodes.

@@ -626,6 +637,6 @@ class IllegalGraphError(ValueError):
     @staticmethod
     def _debug_dump_graph(graph):
         if isinstance(graph, Graph):
-            graph = graph.dump()
+            graph = graph._dump()
         with open('generated/debug.json', 'w') as dump_file:
             json.dump(graph, dump_file, indent=4)
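The new TrainingConfig.__eq__ exists for the dedup pass: two models' inputs may be fused only when their training configs (module plus kwargs) compare equal. A minimal sketch of that check, with hypothetical config values:

    class TrainingConfig:
        def __init__(self, module, kwargs):
            self.module, self.kwargs = module, kwargs

        def __eq__(self, other):
            return self.module == other.module and \
                self.kwargs == other.kwargs


    a = TrainingConfig('nni.retiarii.trainer.PyTorchImageClassificationTrainer',
                       {'dataset_cls': 'MNIST'})
    b = TrainingConfig('nni.retiarii.trainer.PyTorchImageClassificationTrainer',
                       {'dataset_cls': 'MNIST'})
    assert a == b   # same pipeline, so their input nodes can be deduplicated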
nni/retiarii/integration.py  (+4 −1)

@@ -126,7 +126,10 @@ class RetiariiAdvisor(MsgDispatcherBase):
     @staticmethod
     def _process_value(value) -> Any:  # hopefully a float
         if isinstance(value, dict):
-            return value['default']
+            if 'default' in value:
+                return value['default']
+            else:
+                return value
         return value
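The extended _process_value now passes a dict through unchanged when it has no 'default' key, which is how the multi-model trainer's per-model loss dicts survive metric processing. The function is small enough to exercise standalone:

    def _process_value(value):
        if isinstance(value, dict):
            if 'default' in value:
                return value['default']
            else:
                return value
        return value


    assert _process_value(0.97) == 0.97
    assert _process_value({'default': 0.97}) == 0.97
    # per-model losses keyed by model id are returned whole
    assert _process_value({'0': 1.2, '1': 0.8}) == {'0': 1.2, '1': 0.8}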
nni/retiarii/mutator.py  (+1 −7)

The only substantive change in this file is the removal of blank lines inside three docstrings; the affected docstrings are shown with context:

@@ -26,17 +26,13 @@ class Sampler:
 class Mutator:
     """
     Mutates graphs in model to generate new model.
     `Mutator` class will be used in two places:
     1. Inherit `Mutator` to implement graph mutation logic.
     2. Use `Mutator` subclass to implement NAS strategy.
     In scenario 1, the subclass should implement `Mutator.mutate()` interface with `Mutator.choice()`.
     In scenario 2, strategy should use constructor or `Mutator.bind_sampler()` to initialize subclass,
     and then use `Mutator.apply()` to mutate model.
     For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
     # Method names are open for discussion.
     """

     def __init__(self, sampler: Optional[Sampler] = None):

@@ -55,7 +51,6 @@ class Mutator:
         """
         Apply this mutator on a model.
         Returns mutated model.
         The model will be copied before mutation and the original model will not be modified.
         """
         assert self.sampler is not None

@@ -86,7 +81,6 @@ class Mutator:
     def mutate(self, model: Model) -> None:
         """
         Abstract method to be implemented by subclass.
         Mutate a model in place.
         """
         raise NotImplementedError()
nni/retiarii/operation.py  (+5 −1)

@@ -120,8 +120,12 @@ class PyTorchOperation(Operation):
             return f'{output} = [{", ".join(inputs)}]'
         elif self.type == 'aten::mean':
             return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
+        elif self.type == 'aten::size':
+            return f'{output} = {inputs[0]}.size({inputs[1]})'
+        elif self.type == 'aten::view':
+            return f'{output} = {inputs[0]}.view({inputs[1]})'
         else:
-            raise RuntimeError('unsupported operation type: {}'.format(self.type))
+            raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')


 class TensorFlowOperation(Operation):
     def _to_class_name(self) -> str:
nni/retiarii/operation_def/tf_op_def.py  (+1 −1; diff body not expanded on this page)
nni/retiarii/operation_def/torch_op_def.py  (+16 −0)

 from ..operation import PyTorchOperation


+class relu(PyTorchOperation):
+    def to_init_code(self, field):
+        return ''
+
+    def to_forward_code(self, field, output, *inputs) -> str:
+        assert len(inputs) == 1
+        return f'{output} = nn.functional.relu({inputs[0]})'
+
+
 class Flatten(PyTorchOperation):
     def to_init_code(self, field):

@@ -9,6 +17,14 @@ class Flatten(PyTorchOperation):
         assert len(inputs) == 1
         return f'{output} = {inputs[0]}.view({inputs[0]}.size(0), -1)'

+
+class ToDevice(PyTorchOperation):
+    def to_init_code(self, field):
+        return ''
+
+    def to_forward_code(self, field, output, inputs) -> str:
+        assert len(inputs) == 1
+        return f"{output} = {inputs[0]}.to('{self.parameters['device']}')"
+
+
 class Dense(PyTorchOperation):
     def to_init_code(self, field):
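ToDevice.to_forward_code just emits a tensor move. A sketch of the generated line, with plain values in place of the Operation machinery:

    parameters = {'device': 'cuda:1'}      # stands in for self.parameters above
    output, inputs = 'out_1', ['out_0']    # hypothetical tensor names

    print(f"{output} = {inputs[0]}.to('{parameters['device']}')")
    # out_1 = out_0.to('cuda:1')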
nni/retiarii/trainer/__init__.py  (+1 −1)

 from .interface import BaseTrainer
-from .pytorch import PyTorchImageClassificationTrainer
+from .pytorch import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
nni/retiarii/trainer/pytorch/__init__.py  (+1 −1)

-from .base import PyTorchImageClassificationTrainer
+from .base import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
 from .darts import DartsTrainer
 from .enas import EnasTrainer
 from .proxyless import ProxylessTrainer
nni/retiarii/trainer/pytorch/base.py  (+134 −7)

@@ -85,11 +85,13 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
         self._loss_fn = nn.CrossEntropyLoss()
         self._dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls),
                                                        **(dataset_kwargs or {}))
         self._optimizer = getattr(torch.optim, optimizer_cls)(model.parameters(), **(optimizer_kwargs or {}))
         self._trainer_kwargs = trainer_kwargs or {'max_epochs': 10}
         # TODO: we will need at least two (maybe three) data loaders in future.
         self._dataloader = DataLoader(self._dataset, **(dataloader_kwargs or {}))

     def _accuracy(self, input, target):
         _, predict = torch.max(input.data, 1)

@@ -97,18 +99,32 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
         return correct / input.size(0)

     def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
+        x, y = self.training_step_before_model(batch, batch_idx)
+        y_hat = self.model(x)
+        return self.training_step_after_model(x, y, y_hat)
+
+    def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
         x, y = batch
         if self._use_cuda:
-            x, y = x.cuda(), y.cuda()
-        y_hat = self.model(x)
+            x, y = x.cuda(torch.device('cuda:0')), y.cuda(torch.device('cuda:0'))
+        return x, y
+
+    def training_step_after_model(self, x, y, y_hat):
         loss = self._loss_fn(y_hat, y)
         return loss

     def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
+        x, y = self.validation_step_before_model(batch, batch_idx)
+        y_hat = self.model(x)
+        return self.validation_step_after_model(x, y, y_hat)
+
+    def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
         x, y = batch
         if self._use_cuda:
             x, y = x.cuda(), y.cuda()
-        y_hat = self.model(x)
+        return x, y
+
+    def validation_step_after_model(self, x, y, y_hat):
         acc = self._accuracy(y_hat, y)
         return {'val_acc': acc}

@@ -126,9 +142,120 @@ class PyTorchImageClassificationTrainer(BaseTrainer):
     def _train(self):
         for i, batch in enumerate(self._dataloader):
-            self.training_step(batch, i)
+            loss = self.training_step(batch, i)
+            loss.backward()

     def fit(self) -> None:
         for _ in range(self._trainer_kwargs['max_epochs']):
             self._train()
-        nni.report_final_result(self._validate()['val_acc'])
-        # assuming val_acc here
+        # assuming val_acc here
+        nni.report_final_result(self._validate()['val_acc'])
+
+
+class PyTorchMultiModelTrainer(BaseTrainer):
+    def __init__(self, multi_model, kwargs=[]):
+        self.multi_model = multi_model
+        self.kwargs = kwargs
+        self._dataloaders = []
+        self._datasets = []
+        self._optimizers = []
+        self._trainers = []
+        self._loss_fn = nn.CrossEntropyLoss()
+        self.max_steps = None
+        if 'max_steps' in self.kwargs:
+            self.max_steps = self.kwargs['max_steps']
+
+        for m in self.kwargs['model_kwargs']:
+            if m['use_input']:
+                dataset_cls = m['dataset_cls']
+                dataset_kwargs = m['dataset_kwargs']
+                dataloader_kwargs = m['dataloader_kwargs']
+                dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls),
+                                                         **(dataset_kwargs or {}))
+                dataloader = DataLoader(dataset, **(dataloader_kwargs or {}))
+                self._datasets.append(dataset)
+                self._dataloaders.append(dataloader)
+
+            if m['use_output']:
+                optimizer_cls = m['optimizer_cls']
+                optimizer_kwargs = m['optimizer_kwargs']
+                m_header = f"M_{m['model_id']}"
+                one_model_params = []
+                for name, param in multi_model.named_parameters():
+                    name_prefix = '_'.join(name.split('_')[:2])
+                    if m_header == name_prefix:
+                        one_model_params.append(param)
+                optimizer = getattr(torch.optim, optimizer_cls)(one_model_params, **(optimizer_kwargs or {}))
+                self._optimizers.append(optimizer)
+
+    def fit(self) -> None:
+        torch.autograd.set_detect_anomaly(True)
+        max_epochs = max([x['trainer_kwargs']['max_epochs'] for x in self.kwargs['model_kwargs']])
+        for _ in range(max_epochs):
+            self._train()
+
+    def _train(self):
+        for batch_idx, multi_model_batch in enumerate(zip(*self._dataloaders)):
+            for opt in self._optimizers:
+                opt.zero_grad()
+            xs = []
+            ys = []
+            for idx, batch in enumerate(multi_model_batch):
+                x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}')
+                xs.append(x)
+                ys.append(y)
+            y_hats = self.multi_model(*xs)
+            if len(ys) != len(xs):
+                raise ValueError('len(ys) should be equal to len(xs)')
+            losses = []
+            report_loss = {}
+            for output_idx, yhat in enumerate(y_hats):
+                if len(ys) == len(y_hats):
+                    loss = self.training_step_after_model(xs[output_idx], ys[output_idx], yhat)
+                elif len(ys) == 1:
+                    loss = self.training_step_after_model(xs[0], ys[0].to(yhat.get_device()), yhat)
+                else:
+                    raise ValueError('len(ys) should be either 1 or len(y_hats)')
+                losses.append(loss.to("cuda:0"))
+                report_loss[self.kwargs['model_kwargs'][output_idx]['model_id']] = loss.item()
+            summed_loss = sum(losses)
+            summed_loss.backward()
+            for opt in self._optimizers:
+                opt.step()
+            if batch_idx % 50 == 0:
+                nni.report_intermediate_result(report_loss)
+            if self.max_steps and batch_idx >= self.max_steps:
+                return
+
+    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
+        x, y = self.training_step_before_model(batch, batch_idx)
+        y_hat = self.model(x)
+        return self.training_step_after_model(x, y, y_hat)
+
+    def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
+        x, y = batch
+        if device:
+            x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
+        return x, y
+
+    def training_step_after_model(self, x, y, y_hat):
+        loss = self._loss_fn(y_hat, y)
+        return loss
+
+    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
+        x, y = self.validation_step_before_model(batch, batch_idx)
+        y_hat = self.model(x)
+        return self.validation_step_after_model(x, y, y_hat)
+
+    def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
+        x, y = batch
+        if self._use_cuda:
+            x, y = x.cuda(), y.cuda()
+        return x, y
+
+    def validation_step_after_model(self, x, y, y_hat):
+        acc = self._accuracy(y_hat, y)
+        return {'val_acc': acc}
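PyTorchMultiModelTrainer._train advances all grouped models in lockstep via zip(*self._dataloaders), which also silently truncates to the shortest loader. A toy version with lists in place of DataLoaders:

    loader_a = [('a0', 0), ('a1', 1)]
    loader_b = [('b0', 0), ('b1', 1), ('b2', 1)]   # the extra batch is dropped by zip

    for batch_idx, multi_model_batch in enumerate(zip(loader_a, loader_b)):
        xs = [x for x, _ in multi_model_batch]     # one input per grouped model
        print(batch_idx, xs)
    # 0 ['a0', 'b0']
    # 1 ['a1', 'b1']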