OpenDAS / nni / Commits / 7d1acfbd

Unverified commit 7d1acfbd, authored Dec 11, 2020 by Zhenhua Han, committed by GitHub on Dec 11, 2020
Parent: 165756cc

[Retiarii] cross-graph optimization: input deduplication (#3105)

Showing 20 changed files with 767 additions and 30 deletions (+767 -30)
.gitignore                                                        +3    -0
nni/retiarii/__init__.py                                          +2    -1
nni/retiarii/codegen/pytorch.py                                   +11   -5
nni/retiarii/execution/api.py                                     +6    -1
nni/retiarii/execution/cgo_engine.py                              +151  -0
nni/retiarii/execution/logical_optimizer/__init__.py              +0    -0
nni/retiarii/execution/logical_optimizer/interface.py             +11   -0
nni/retiarii/execution/logical_optimizer/logical_plan.py          +297  -0
nni/retiarii/execution/logical_optimizer/opt_batching.py          +10   -0
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py       +88   -0
nni/retiarii/execution/logical_optimizer/opt_weight_sharing.py    +10   -0
nni/retiarii/graph.py                                             +15   -4
nni/retiarii/integration.py                                       +4    -1
nni/retiarii/mutator.py                                           +1    -7
nni/retiarii/operation.py                                         +5    -1
nni/retiarii/operation_def/tf_op_def.py                           +1    -1
nni/retiarii/operation_def/torch_op_def.py                        +16   -0
nni/retiarii/trainer/__init__.py                                  +1    -1
nni/retiarii/trainer/pytorch/__init__.py                          +1    -1
nni/retiarii/trainer/pytorch/base.py                              +134  -7
.gitignore · View file @ 7d1acfbd

```diff
@@ -97,3 +97,6 @@ venv.bak/
 # VSCode
 .vscode
 .vs
+.history
+generated/
+test/ut/retiarii/_debug_graph_data.json
```
nni/retiarii/__init__.py · View file @ 7d1acfbd

The package imports are reshuffled and extended; after this change the module reads:

```python
from .graph import *
from .execution import *
from .mutator import *
from .operation import Operation
from .model_apis import nn
```
nni/retiarii/codegen/pytorch.py · View file @ 7d1acfbd

```diff
@@ -7,11 +7,12 @@ from ..operation import Operation, Cell
 _logger = logging.getLogger(__name__)


-def model_to_pytorch_script(model: Model) -> str:
+def model_to_pytorch_script(model: Model, placement=None) -> str:
     graphs = []
     total_pkgs = set()
     for name, cell in model.graphs.items():
-        import_pkgs, graph_code = graph_to_pytorch_model(name, cell)
+        import_pkgs, graph_code = graph_to_pytorch_model(
+            name, cell, placement=placement)
         graphs.append(graph_code)
         total_pkgs.update(import_pkgs)
     # TODO: set correct PATH for the packages (after launch refactor)
@@ -23,6 +24,7 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
     _logger.info('sorted_incoming_edges: {}'.format(edges))
     if not edges:
         return []
+    _logger.info(f'all tail_slots are None: {[edge.tail_slot for edge in edges]}')
     if all(edge.tail_slot is None for edge in edges):
         return edges
     if all(isinstance(edge.tail_slot, int) for edge in edges):
@@ -52,13 +54,14 @@ def _format_inputs(node: Node) -> List[str]:
             inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
     return inputs


-def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
-    nodes = graph.nodes
+def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str:
+    nodes = graph.topo_sort()
     # FIXME: topological sort is needed here
     # handle module node and function node differently
     # only need to generate code for module here
     import_pkgs = set()
     node_codes = []
+    placement_codes = []
     for node in nodes:
         if node.operation:
             pkg_name = node.operation.get_import_pkg()
@@ -66,6 +69,9 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph) -> str:
                 import_pkgs.add(pkg_name)
             node_code = node.operation.to_init_code(node.name)
             if node_code is not None:
-                node_codes.append(node_code)
+                if placement and node in placement and len(node_code) > 0:
+                    node_codes.append(f"{node_code}.to('{placement[node].device}')")
+                else:
+                    node_codes.append(node_code)
     if graph.input_node.operation.io_names is None:
```
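The placement hook above only decorates each generated `__init__` statement with a `.to(device)` call when the node has an assigned device. A minimal standalone sketch of that string rewrite (the stand-in names below are illustrative, not NNI APIs):

```python
class StubDevice:
    """Stand-in for PhysicalDevice: only the .device attribute is used here."""
    def __init__(self, device):
        self.device = device

placement = {'conv1': StubDevice('cuda:1')}   # node -> device, keyed by name here
node, node_code = 'conv1', 'self.conv1 = nn.Conv2d(3, 16, 3)'

# Same branch structure as graph_to_pytorch_model above.
if placement and node in placement and len(node_code) > 0:
    line = f"{node_code}.to('{placement[node].device}')"
else:
    line = node_code

print(line)  # self.conv1 = nn.Conv2d(3, 16, 3).to('cuda:1')
```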
nni/retiarii/execution/api.py · View file @ 7d1acfbd

The import block gains `os`, `importlib.util`, and the new CGO engine:

```python
import time
import os
import importlib.util

from typing import *

from ..graph import Model, ModelStatus
from .base import BaseExecutionEngine
from .cgo_engine import CGOExecutionEngine
from .interface import *
from .listener import DefaultListener
```

```diff
@@ -21,6 +23,9 @@ def get_execution_engine() -> BaseExecutionEngine:
     """
     global _execution_engine
     if _execution_engine is None:
-        _execution_engine = BaseExecutionEngine()
+        if os.environ.get('CGO') == 'true':
+            _execution_engine = CGOExecutionEngine()
+        else:
+            _execution_engine = BaseExecutionEngine()
     return _execution_engine
```
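The engine is constructed lazily on first use, so opting into cross-graph optimization is a matter of setting the environment variable before the first `get_execution_engine()` call. A sketch, assuming the package is importable as laid out in this commit:

```python
import os

# Must be set before the first call; the engine singleton is cached afterwards.
os.environ['CGO'] = 'true'

from nni.retiarii.execution.api import get_execution_engine

engine = get_execution_engine()  # CGOExecutionEngine instead of BaseExecutionEngine
```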
nni/retiarii/execution/cgo_engine.py · new file 0 → 100644 · View file @ 7d1acfbd

```python
import logging
import json

from typing import *

from .interface import AbstractExecutionEngine, AbstractGraphListener, WorkerInfo
from .. import codegen, utils
from ..graph import Model, ModelStatus, MetricData
from ..integration import send_trial, receive_trial_parameters, get_advisor
from .logical_optimizer.logical_plan import LogicalPlan, PhysicalDevice
from .logical_optimizer.opt_dedup_input import DedupInputOptimizer
from .base import BaseGraphData

_logger = logging.getLogger(__name__)


class CGOExecutionEngine(AbstractExecutionEngine):
    def __init__(self, n_model_per_graph=4) -> None:
        self._listeners: List[AbstractGraphListener] = []
        self._running_models: Dict[int, Model] = dict()
        self.logical_plan_counter = 0
        self.n_model_per_graph = n_model_per_graph
        self._optimizers = [DedupInputOptimizer()]
        self._original_models = {}
        self._original_model_to_multi_model = {}

        # register advisor callbacks
        advisor = get_advisor()
        advisor.send_trial_callback = self._send_trial_callback
        advisor.request_trial_jobs_callback = self._request_trial_jobs_callback
        advisor.trial_end_callback = self._trial_end_callback
        advisor.intermediate_metric_callback = self._intermediate_metric_callback
        advisor.final_metric_callback = self._final_metric_callback

    def add_optimizer(self, opt):
        self._optimizers.append(opt)

    def submit_models(self, *models: List[Model]) -> None:
        _logger.info(f'{len(models)} models are submitted')
        logical = self._build_logical(models)
        for opt in self._optimizers:
            opt.convert(logical)
        phy_models_and_placements = self._assemble(logical)
        for model, placement, grouped_models in phy_models_and_placements:
            data = BaseGraphData(codegen.model_to_pytorch_script(model, placement=placement),
                                 model.training_config.module, model.training_config.kwargs)
            for m in grouped_models:
                self._original_models[m.model_id] = m
                self._original_model_to_multi_model[m.model_id] = model
            self._running_models[send_trial(data.dump())] = model
        # for model in models:
        #     data = BaseGraphData(codegen.model_to_pytorch_script(model),
        #                          model.config['trainer_module'], model.config['trainer_kwargs'])
        #     self._running_models[send_trial(data.dump())] = model

    def _assemble(self, logical_plan: LogicalPlan) -> List[Tuple[Model, PhysicalDevice]]:
        # unique_models = set()
        # for node in logical_plan.graph.nodes:
        #     if node.graph.model not in unique_models:
        #         unique_models.add(node.graph.model)
        # return [m for m in unique_models]
        grouped_models: List[Dict[Model, PhysicalDevice]] = AssemblePolicy().group(logical_plan)
        phy_models_and_placements = []
        for multi_model in grouped_models:
            model, model_placement = logical_plan.assemble(multi_model)
            phy_models_and_placements.append((model, model_placement, multi_model.keys()))
        return phy_models_and_placements

    def _build_logical(self, models: List[Model]) -> LogicalPlan:
        logical_plan = LogicalPlan(id=self.logical_plan_counter)
        for model in models:
            logical_plan.add_model(model)
        self.logical_plan_counter += 1
        return logical_plan

    def register_graph_listener(self, listener: AbstractGraphListener) -> None:
        self._listeners.append(listener)

    def _send_trial_callback(self, parameter: dict) -> None:
        for listener in self._listeners:
            listener.on_resource_used(0)  # FIXME: find the real resource id

    def _request_trial_jobs_callback(self, num_trials: int) -> None:
        for listener in self._listeners:
            listener.on_resource_available([0] * num_trials)  # FIXME: find the real resource id

    def _trial_end_callback(self, trial_id: int, success: bool) -> None:
        model = self._running_models[trial_id]
        if success:
            model.status = ModelStatus.Trained
        else:
            model.status = ModelStatus.Failed
        for model_id in self._original_model_to_multi_model:
            if self._original_model_to_multi_model[model_id] == model:
                original_model = self._original_models[model_id]
                if success:
                    original_model.status = ModelStatus.Trained
                else:
                    original_model.status = ModelStatus.Failed
                for listener in self._listeners:
                    listener.on_training_end(original_model, success)

    def _intermediate_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        # model = self._running_models[trial_id]
        merged_metrics = dict(metrics)
        for model_id in merged_metrics:
            int_model_id = int(model_id)
            self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
            # model.intermediate_metrics.append(metrics)
            for listener in self._listeners:
                listener.on_intermediate_metric(self._original_models[int_model_id],
                                                merged_metrics[model_id])

    def _final_metric_callback(self, trial_id: int, metrics: MetricData) -> None:
        merged_metrics = dict(metrics)
        for model_id in merged_metrics:
            int_model_id = int(model_id)
            self._original_models[int_model_id].intermediate_metrics.append(merged_metrics[model_id])
            # model.intermediate_metrics.append(metrics)
            for listener in self._listeners:
                listener.on_metric(self._original_models[int_model_id], merged_metrics[model_id])

    def query_available_resource(self) -> List[WorkerInfo]:
        raise NotImplementedError  # move the method from listener to here?

    @classmethod
    def trial_execute_graph(cls) -> None:
        """
        Initialize the model, hand it over to trainer.
        """
        graph_data = BaseGraphData.load(receive_trial_parameters())
        _logger.info('CGO_ENGINE trial parameters received')
        with open('_generated_model.py', 'w') as f:
            f.write(graph_data.model_script)
        # with open('_debug_graph_data.json', 'w') as f:
        #     json.dump(graph_data.dump(), f)
        trainer_cls = utils.import_(graph_data.training_module)
        model_cls = utils.import_(f"_generated_model.{graph_data.training_kwargs['model_cls']}")
        trainer_instance = trainer_cls(model_cls(), graph_data.training_kwargs)
        trainer_instance.fit()


class AssemblePolicy:
    @staticmethod
    def group(logical_plan):
        group_model = {}
        for idx, m in enumerate(logical_plan.models):
            group_model[m] = PhysicalDevice('server', f'cuda:{idx}')
        return [group_model]
```
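`AssemblePolicy` currently places every submitted model into a single group, one CUDA device per model on one server. The same round-robin assignment as a standalone snippet (plain tuples instead of `Model`/`PhysicalDevice` objects):

```python
def group(models):
    # Mirrors AssemblePolicy.group: one group, mapping each model to
    # ('server', 'cuda:<index>').
    return [{m: ('server', f'cuda:{idx}') for idx, m in enumerate(models)}]

print(group(['model_a', 'model_b']))
# [{'model_a': ('server', 'cuda:0'), 'model_b': ('server', 'cuda:1')}]
```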
nni/retiarii/execution/logical_optimizer/__init__.py · new file 0 → 100644 (empty) · View file @ 7d1acfbd
nni/retiarii/execution/logical_optimizer/interface.py · new file 0 → 100644 · View file @ 7d1acfbd

```python
from abc import *
from typing import *

from .logical_plan import LogicalPlan


class AbstractOptimizer(ABC):
    def __init__(self) -> None:
        pass

    def convert(self, logical_plan: LogicalPlan) -> None:
        raise NotImplementedError
```
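Concrete passes such as `DedupInputOptimizer` subclass this interface and mutate the `LogicalPlan` in place; `CGOExecutionEngine.add_optimizer` registers extra passes. A minimal sketch of a custom pass (this statistics pass is hypothetical, not part of the commit):

```python
import logging

from nni.retiarii.execution.logical_optimizer.interface import AbstractOptimizer
from nni.retiarii.execution.logical_optimizer.logical_plan import LogicalPlan

_logger = logging.getLogger(__name__)


class GraphStatsOptimizer(AbstractOptimizer):
    """Hypothetical pass: logs the plan size and mutates nothing."""

    def convert(self, logical_plan: LogicalPlan) -> None:
        graph = logical_plan.logical_graph
        _logger.info('logical plan %d: %d hidden nodes, %d edges',
                     logical_plan.id, len(graph.hidden_nodes), len(graph.edges))


# Registered the same way as the built-in dedup pass:
#     engine = CGOExecutionEngine()
#     engine.add_optimizer(GraphStatsOptimizer())
```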
nni/retiarii/execution/logical_optimizer/logical_plan.py · new file 0 → 100644 · View file @ 7d1acfbd

```python
import copy
import logging
from typing import *

from nni.retiarii.operation import Operation, _IOPseudoOperation
from nni.retiarii.graph import Model, Graph, Edge, Node, Cell


class PhysicalDevice:
    def __init__(self, server: str, device: str):
        self.server = server
        self.device = device

    def __eq__(self, o) -> bool:
        return self.server == o.server and self.device == o.device

    def __hash__(self) -> int:
        return hash(self.server + '_' + self.device)


class AbstractLogicalNode(Node):
    def __init__(self, graph, node_id, name, operation, _internal=False):
        super().__init__(graph, node_id, name, operation, _internal=_internal)

    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
        raise NotImplementedError

    def _fork_to(self, graph: Graph):
        raise NotImplementedError


class LogicalGraph(Graph):
    def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
        super().__init__(model, graph_id, name='logical_' + name, _internal=_internal)

    def _dump(self) -> Any:
        nodes_dump = {}
        for node in self.hidden_nodes:
            if isinstance(node, OriginNode):
                nodes_dump[f"{node.original_graph.model.model_id}_{node.name}"] = node._dump()
            else:
                nodes_dump[f"{node.graph.model.model_id}_{node.name}"] = node._dump()

        edges_dump = []
        for edge in self.edges:
            if isinstance(edge.head, OriginNode):
                head_info = f'{edge.head.original_graph.model.model_id}_{edge.head.name}'
            else:
                head_info = edge.head.name
            if isinstance(edge.tail, OriginNode):
                tail_info = f'{edge.tail.original_graph.model.model_id}_{edge.tail.name}'
            else:
                tail_info = edge.tail.name
            edges_dump.append((head_info, tail_info))
        return {
            'inputs': self.input_node.operation.io_names,
            'outputs': self.output_node.operation.io_names,
            'nodes': nodes_dump,
            'edges': edges_dump
        }

    def _fork_to(self, model: Model) -> Graph:
        new_graph = Graph(model, self.id, self.name, _internal=True)._register()

        for node in self.hidden_nodes:
            if isinstance(node, AbstractLogicalNode):
                node._fork_to(new_graph)
            else:
                Node(new_graph, node.id, node.name, node.operation, _internal=True)._register()

        id_to_new_node = {node.__repr__(): node for node in new_graph.nodes}

        for edge in self.edges:
            new_head = id_to_new_node[edge.head.__repr__()]
            new_tail = id_to_new_node[edge.tail.__repr__()]
            Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

        return new_graph


class OriginNode(AbstractLogicalNode):
    def __init__(self, logical_graph: LogicalGraph,
                 original_graph: Graph, original_node: Node,
                 name: str, operation, _internal=False):
        super().__init__(logical_graph, original_node.id, name, operation)
        self.original_graph = original_graph
        self.original_node = original_node

    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
        model_id = self.original_node.graph.model.model_id
        new_node = Node(self.original_node.graph, self.original_node.id,
                        f"M_{model_id}_" + self.original_node.name,
                        self.original_node.operation)
        return new_node, multi_model_placement[self.original_node.graph.model]

    def __repr__(self):
        return f'OriginNode(id={self.id}, name={self.name}, ' \
               f'operation={self.operation}, origin_model_id={self.original_graph.model.model_id})'

    def _fork_to(self, graph: Graph):
        OriginNode(graph, self.original_graph, self.original_node,
                   self.name, self.operation)._register()


class LogicalPlan:
    def __init__(self, id=0) -> None:
        self.lp_model = Model(_internal=True)
        self.id = id
        self.logical_graph = LogicalGraph(self.lp_model, id, name=f'{id}', _internal=True)._register()
        self.lp_model._root_graph_name = self.logical_graph.name
        self.models = []

    def add_model(self, model: Model):
        self.models.append(model)
        # Only optimize the root graph.
        self._merge_graph(model.root_graph)

    def _merge_graph(self, from_graph):
        to_graph = self.logical_graph
        id_to_new_node = {}  # old node ID -> new node object

        for old_node in from_graph.nodes:
            new_node = OriginNode(to_graph, old_node.graph,
                                  old_node, old_node.name,
                                  old_node.operation, _internal=True)._register()
            id_to_new_node[old_node.id] = new_node

        for edge in from_graph.edges:
            new_head = id_to_new_node[edge.head.id]
            new_tail = id_to_new_node[edge.tail.id]
            Edge((new_head, edge.head_slot), (new_tail, edge.tail_slot), _internal=True)._register()

    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) \
            -> Tuple[Model, Dict[Node, PhysicalDevice], List[Model]]:
        phy_model = Model(_internal=True)  # self.lp_model.fork()
        phy_graph = self.lp_model.root_graph._fork_to(phy_model)

        # Add a flag to mark multi-model in graph json.
        # Multi-model has a list of training configs in kwargs['model_kwargs'].
        if len(multi_model_placement) > 1:
            phy_model.training_config.kwargs['is_multi_model'] = True
            phy_model.training_config.kwargs['model_cls'] = phy_graph.name
            phy_model.training_config.kwargs['model_kwargs'] = []
            # FIXME: allow user to specify
            phy_model.training_config.module = 'nni.retiarii.trainer.PyTorchMultiModelTrainer'

        # merge sub-graphs
        for model in multi_model_placement:
            for graph_name in model.graphs:
                if graph_name != model._root_graph_name:
                    model.graphs[graph_name]._fork_to(phy_model, name_prefix=f'M_{model.model_id}_')

        # When replacing logical nodes, merge the training configs when
        # input/output nodes are replaced.
        training_config_slot = {}  # Model ID -> Slot ID
        input_slot_mapping = {}
        output_slot_mapping = {}

        # Replace all logical nodes with executable physical nodes
        hidden_nodes = phy_graph.hidden_nodes.copy()
        node_placements = {}
        for node in hidden_nodes:
            if isinstance(node, OriginNode):
                model_id = node.original_graph.model.model_id
                if node.original_graph.model not in multi_model_placement:
                    for edge in node.incoming_edges:
                        edge.remove()
                    for edge in node.outgoing_edges:
                        edge.remove()
                    node.remove()
                    continue
            if isinstance(node, AbstractLogicalNode):
                new_node, placement = node.assemble(multi_model_placement)
                if isinstance(new_node.operation, _IOPseudoOperation):
                    model_id = new_node.graph.model.model_id
                    if model_id not in training_config_slot:
                        phy_model.training_config.kwargs['model_kwargs'].append(
                            new_node.graph.model.training_config.kwargs.copy())
                        training_config_slot[model_id] = \
                            len(phy_model.training_config.kwargs['model_kwargs']) - 1
                        slot = training_config_slot[model_id]
                        phy_model.training_config.kwargs['model_kwargs'][slot]['model_id'] = model_id
                        phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = False
                        phy_model.training_config.kwargs['model_kwargs'][slot]['use_output'] = False
                    else:
                        slot = training_config_slot[model_id]
                    # If a model's inputs/outputs are not used in the multi-model,
                    # the codegen and trainer should not generate and use them.
                    # "use_input" and "use_output" mark whether an input/output
                    # of a model is used in a multi-model.
                    if new_node.operation.type == '_inputs':
                        input_slot_mapping[new_node] = slot
                        phy_model.training_config.kwargs['model_kwargs'][slot]['use_input'] = True
                    if new_node.operation.type == '_outputs':
                        output_slot_mapping[new_node] = slot
                        phy_model.training_config.kwargs['model_kwargs'][slot]['use_output'] = True

                self.node_replace(node, new_node)
                if isinstance(new_node.operation, Cell):
                    old_cell_name = new_node.operation.cell_name
                    new_node.operation = copy.deepcopy(new_node.operation)
                    new_node.operation.cell_name = f'M_{model_id}_{old_cell_name}'
                node_placements[new_node] = placement
                node.remove()

        # If two nodes are placed on different devices, use a ToDevice op to copy the node
        existing_edges = phy_graph.edges.copy()
        # Avoid copying a node multiple times onto the same device
        copied_op: Dict[Tuple[Node, PhysicalDevice], Node] = {}
        for edge in existing_edges:
            head_placement = node_placements[edge.head]
            tail_placement = node_placements[edge.tail]
            if head_placement != tail_placement:
                if head_placement.server != tail_placement.server:
                    raise ValueError('Cross-server placement is not supported.')
                # Same server, different devices
                if (edge.head, tail_placement) in copied_op:
                    to_node = copied_op[(edge.head, tail_placement)]
                else:
                    to_operation = Operation.new('ToDevice', {"device": tail_placement.device})
                    to_node = Node(phy_graph, phy_model._uid(),
                                   edge.head.name + "_to_" + edge.tail.name,
                                   to_operation)._register()
                    Edge((edge.head, edge.head_slot), (to_node, None), _internal=True)._register()
                    copied_op[(edge.head, tail_placement)] = to_node
                edge.head = to_node
                edge.head_slot = None

        # merge all input nodes into one with multiple slots
        input_nodes = []
        for node in phy_graph.hidden_nodes:
            if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_inputs':
                input_nodes.append(node)

        for edge in phy_graph.edges:
            if edge.head in input_nodes:
                edge.head_slot = input_slot_mapping[edge.head]
                edge.head = phy_graph.input_node

        # merge all output nodes into one with multiple slots
        output_nodes = []
        for node in phy_graph.hidden_nodes:
            if isinstance(node.operation, _IOPseudoOperation) and node.operation.type == '_outputs':
                output_nodes.append(node)

        for edge in phy_graph.edges:
            if edge.tail in output_nodes:
                edge.tail_slot = output_slot_mapping[edge.tail]
                edge.tail = phy_graph.output_node

        for node in input_nodes:
            node.remove()
        for node in output_nodes:
            node.remove()

        return phy_model, node_placements

    def node_replace(self, old_node: Node, new_node: Node,
                     input_slot_mapping=None, output_slot_mapping=None):
        # TODO: currently, only a single input slot and output slot are supported.
        if input_slot_mapping is not None or output_slot_mapping is not None:
            raise ValueError('Slot mapping is not supported')

        phy_graph = old_node.graph
        new_node.graph = phy_graph
        new_node._register()

        for edge in phy_graph.edges:
            if edge.head == old_node:
                edge.head = new_node
            elif edge.tail == old_node:
                edge.tail = new_node

        # After the replacement, there might be multiple duplicated edges
        # with the same input and output nodes, which should be de-duplicated.
        self._remove_duplicated_edges()

    def _remove_duplicated_edges(self):
        # TODO: duplicated edges do not appear if only dedup of inputs is supported.
        # Duplicated edges appear when a chain of prefix nodes is deduplicated.
        pass
```
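The cross-device copy step in `assemble` has a simple shape: whenever an edge crosses devices on the same server, the head's output is routed through a copy node, and a head is copied at most once per target device. A standalone sketch of that rewrite over a toy edge list (plain strings and dicts, not NNI classes):

```python
# Toy edges: (head, tail); placements: node -> device on one server.
edges = [('conv_a', 'cat'), ('conv_b', 'cat')]
placements = {'conv_a': 'cuda:0', 'conv_b': 'cuda:1', 'cat': 'cuda:0'}

copied = {}   # (head, target_device) -> copy node, so a head is copied once per device
new_edges = []
for head, tail in edges:
    if placements[head] != placements[tail]:
        key = (head, placements[tail])
        if key not in copied:
            copy_node = f'{head}_to_{tail}'          # plays the role of a ToDevice node
            placements[copy_node] = placements[tail]
            new_edges.append((head, copy_node))      # feed head into the copy op
            copied[key] = copy_node
        head = copied[key]
    new_edges.append((head, tail))

print(new_edges)
# [('conv_a', 'cat'), ('conv_b', 'conv_b_to_cat'), ('conv_b_to_cat', 'cat')]
```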
nni/retiarii/execution/logical_optimizer/opt_batching.py · new file 0 → 100644 · View file @ 7d1acfbd

```python
from .base_optimizer import BaseOptimizer
from .logical_plan import LogicalPlan


class BatchingOptimizer(BaseOptimizer):
    def __init__(self) -> None:
        pass

    def convert(self, logical_plan: LogicalPlan) -> None:
        pass
```
nni/retiarii/execution/logical_optimizer/opt_dedup_input.py · new file 0 → 100644 · View file @ 7d1acfbd

```python
from typing import *

from .interface import AbstractOptimizer
from .logical_plan import LogicalPlan, AbstractLogicalNode, LogicalGraph, OriginNode, PhysicalDevice
from nni.retiarii import Graph, Node, Model
from nni.retiarii.operation import _IOPseudoOperation

_supported_training_modules = ['nni.retiarii.trainer.PyTorchImageClassificationTrainer']


class DedupInputNode(AbstractLogicalNode):
    def __init__(self, logical_graph: LogicalGraph, id: int,
                 nodes_to_dedup: List[Node], _internal=False):
        super().__init__(logical_graph, id,
                         "Dedup_" + nodes_to_dedup[0].name,
                         nodes_to_dedup[0].operation)
        self.origin_nodes: List[OriginNode] = nodes_to_dedup.copy()

    def assemble(self, multi_model_placement: Dict[Model, PhysicalDevice]) -> Tuple[Node, PhysicalDevice]:
        for node in self.origin_nodes:
            if node.original_graph.model in multi_model_placement:
                new_node = Node(node.original_graph, node.id,
                                f'M_{node.original_graph.model.model_id}_{node.name}',
                                node.operation)
                return new_node, multi_model_placement[node.original_graph.model]
        raise ValueError(f'DedupInputNode {self.name} does not contain nodes from multi_model')

    def _fork_to(self, graph: Graph):
        DedupInputNode(graph, self.id, self.origin_nodes)._register()

    def __repr__(self) -> str:
        return f'DedupNode(id={self.id}, name={self.name}, ' \
               f'len(nodes_to_dedup)={len(self.origin_nodes)})'


class DedupInputOptimizer(AbstractOptimizer):
    def __init__(self) -> None:
        pass

    def _check_deduplicate_by_node(self, root_node, node_to_check):
        if root_node == node_to_check:
            return True
        if root_node.operation.type == '_inputs' and \
                node_to_check.operation.type == '_inputs' and \
                isinstance(root_node, OriginNode) and \
                isinstance(node_to_check, OriginNode):
            if root_node.original_graph.model.training_config.module not in _supported_training_modules:
                return False
            if root_node.original_graph.model.training_config == \
                    node_to_check.original_graph.model.training_config:
                return True
            else:
                return False
        else:
            return False

    def convert(self, logical_plan: LogicalPlan) -> None:
        nodes_to_skip = set()
        while True:  # repeat until the logical_graph converges
            input_nodes = logical_plan.logical_graph.get_nodes_by_type("_inputs")
            # _PseudoOperation(type_name="_inputs"))
            root_node = None
            for node in input_nodes:
                if node in nodes_to_skip:
                    continue
                root_node = node
                break
            if root_node is None:
                break  # end of convert
            else:
                nodes_to_dedup = []
                for node in input_nodes:
                    if node in nodes_to_skip:
                        continue
                    if self._check_deduplicate_by_node(root_node, node):
                        nodes_to_dedup.append(node)
                assert len(nodes_to_dedup) >= 1
                if len(nodes_to_dedup) == 1:
                    assert nodes_to_dedup[0] == root_node
                    nodes_to_skip.add(root_node)
                else:
                    dedup_node = DedupInputNode(logical_plan.logical_graph,
                                                logical_plan.lp_model._uid(),
                                                nodes_to_dedup)._register()
                    for edge in logical_plan.logical_graph.edges:
                        if edge.head in nodes_to_dedup:
                            edge.head = dedup_node
                        if edge.tail in nodes_to_dedup:
                            edge.tail = dedup_node
                    for node in nodes_to_dedup:
                        node.remove()
```
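The `convert` loop is a fixed-point grouping: pick an unprocessed `_inputs` node as root, collect every other input node whose model has an identical training config, and fuse them into one `DedupInputNode`. The same grouping logic on plain data (stand-in dicts, not NNI classes; in the real pass, single-node groups are skipped rather than fused):

```python
# Each toy "input node" carries the training config of its model.
inputs = [
    {'name': 'in_m1', 'config': ('ImageClassificationTrainer', 'MNIST')},
    {'name': 'in_m2', 'config': ('ImageClassificationTrainer', 'MNIST')},
    {'name': 'in_m3', 'config': ('ImageClassificationTrainer', 'CIFAR10')},
]

skip = set()
groups = []
while True:
    root = next((n['name'] for n in inputs if n['name'] not in skip), None)
    if root is None:
        break
    root_cfg = next(n['config'] for n in inputs if n['name'] == root)
    dedup = [n['name'] for n in inputs
             if n['name'] not in skip and n['config'] == root_cfg]
    skip.update(dedup)   # every grouped node is processed exactly once
    groups.append(dedup)

print(groups)  # [['in_m1', 'in_m2'], ['in_m3']]
```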
nni/retiarii/execution/logical_optimizer/opt_weight_sharing.py · new file 0 → 100644 · View file @ 7d1acfbd

```python
from .base_optimizer import BaseOptimizer
from .logical_plan import LogicalPlan


class WeightSharingOptimizer(BaseOptimizer):
    def __init__(self) -> None:
        pass

    def convert(self, logical_plan: LogicalPlan) -> None:
        pass
```
nni/retiarii/graph.py · View file @ 7d1acfbd

```diff
@@ -5,6 +5,7 @@ Model representation.
 import copy
 from enum import Enum
 import json
+from collections import defaultdict
 from typing import (Any, Dict, List, Optional, Tuple, Union, overload)

 from .operation import Cell, Operation, _IOPseudoOperation
@@ -51,6 +52,10 @@ class TrainingConfig:
             'kwargs': self.kwargs
         }

+    def __eq__(self, other):
+        return self.module == other.module and \
+            self.kwargs == other.kwargs
+

 class Model:
     """
@@ -311,6 +316,13 @@ class Graph:
         """
         return [node for node in self.hidden_nodes if node.operation.type == operation_type]

+    def get_node_by_id(self, id: int) -> Optional['Node']:
+        """
+        Returns the node which has the specified id, or `None` if no node has this id.
+        """
+        found = [node for node in self.nodes if node.id == id]
+        return found[0] if found else None
+
     def get_nodes_by_label(self, label: str) -> List['Node']:
         return [node for node in self.hidden_nodes if node.label == label]
@@ -347,8 +359,8 @@ class Graph:
     def __eq__(self, other: object) -> bool:
         return self is other

-    def _fork_to(self, model: Model) -> 'Graph':
-        new_graph = Graph(model, self.id, self.name, _internal=True)._register()
+    def _fork_to(self, model: Model, name_prefix='') -> 'Graph':
+        new_graph = Graph(model, self.id, name_prefix + self.name, _internal=True)._register()
         # TODO: use node copy instead
         new_graph.input_node.operation.io_names = self.input_node.operation.io_names
         new_graph.output_node.operation.io_names = self.output_node.operation.io_names
@@ -544,7 +556,6 @@ class Node:
         ret['label'] = self.label
         return ret

-
 class Edge:
     """
     A tensor, or "data flow", between two nodes.
@@ -626,6 +637,6 @@ class IllegalGraphError(ValueError):
     @staticmethod
     def _debug_dump_graph(graph):
         if isinstance(graph, Graph):
-            graph = graph.dump()
+            graph = graph._dump()
         with open('generated/debug.json', 'w') as dump_file:
             json.dump(graph, dump_file, indent=4)
```
nni/retiarii/integration.py · View file @ 7d1acfbd

`_process_value` now unwraps dict-valued metrics. After the change:

@@ -126,7 +126,10 @@ class RetiariiAdvisor(MsgDispatcherBase):

```python
    @staticmethod
    def _process_value(value) -> Any:  # hopefully a float
        if isinstance(value, dict):
            if 'default' in value:
                return value['default']
            else:
                return value
        return value
```
nni/retiarii/mutator.py · View file @ 7d1acfbd

The `Mutator` docstrings are trimmed:

```diff
@@ -26,17 +26,13 @@ class Sampler:
 class Mutator:
     """
     Mutates graphs in model to generate new model.
     `Mutator` class will be used in two places:
     1. Inherit `Mutator` to implement graph mutation logic.
     2. Use `Mutator` subclass to implement NAS strategy.
     In scenario 1, the subclass should implement `Mutator.mutate()` interface with `Mutator.choice()`.
     In scenario 2, strategy should use constructor or `Mutator.bind_sampler()` to initialize subclass,
     and then use `Mutator.apply()` to mutate model.
     For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
-    # Method names are open for discussion.
     """

     def __init__(self, sampler: Optional[Sampler] = None):
```

@@ -55,7 +51,6 @@

```python
        """
        Apply this mutator on a model.
        Returns mutated model.
        The model will be copied before mutation and the original model will not be modified.
        """
        assert self.sampler is not None
```

@@ -86,7 +81,6 @@

```python
    def mutate(self, model: Model) -> None:
        """
        Abstract method to be implemented by subclass.
        Mutate a model in place.
        """
        raise NotImplementedError()
```
nni/retiarii/operation.py · View file @ 7d1acfbd

```diff
@@ -120,8 +120,12 @@ class PyTorchOperation(Operation):
             return f'{output} = [{", ".join(inputs)}]'
         elif self.type == 'aten::mean':
             return f'{output} = torch.mean({inputs[0]}, {", ".join(inputs[1:-1])}, out={inputs[-1]})'
+        elif self.type == 'aten::size':
+            return f'{output} = {inputs[0]}.size({inputs[1]})'
+        elif self.type == 'aten::view':
+            return f'{output} = {inputs[0]}.view({inputs[1]})'
         else:
-            raise RuntimeError('unsupported operation type: {}'.format(self.type))
+            raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')


 class TensorFlowOperation(Operation):
     def _to_class_name(self) -> str:
```
nni/retiarii/operation_def/tf_op_def.py · View file @ 7d1acfbd
nni/retiarii/operation_def/torch_op_def.py · View file @ 7d1acfbd

```diff
 from ..operation import PyTorchOperation


+class relu(PyTorchOperation):
+    def to_init_code(self, field):
+        return ''
+
+    def to_forward_code(self, field, output, *inputs) -> str:
+        assert len(inputs) == 1
+        return f'{output} = nn.functional.relu({inputs[0]})'
+
+
 class Flatten(PyTorchOperation):
     def to_init_code(self, field):
@@ -9,6 +17,14 @@ class Flatten(PyTorchOperation):
         assert len(inputs) == 1
         return f'{output} = {inputs[0]}.view({inputs[0]}.size(0), -1)'


+class ToDevice(PyTorchOperation):
+    def to_init_code(self, field):
+        return ''
+
+    def to_forward_code(self, field, output, inputs) -> str:
+        assert len(inputs) == 1
+        return f"{output} = {inputs[0]}.to('{self.parameters['device']}')"
+
+
 class Dense(PyTorchOperation):
     def to_init_code(self, field):
```
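`ToDevice` is the op that `LogicalPlan.assemble` inserts on cross-device edges; its codegen is a one-line `.to()` call in the generated `forward()`. A quick check of the string it emits (paraphrasing the class as a free function, without NNI imports):

```python
# Stand-in for ToDevice.to_forward_code: emit a device copy in the
# generated forward() body.
def to_forward_code(parameters, output, inputs):
    assert len(inputs) == 1
    return f"{output} = {inputs[0]}.to('{parameters['device']}')"

print(to_forward_code({'device': 'cuda:1'}, 'h1', ['h0']))
# h1 = h0.to('cuda:1')
```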
nni/retiarii/trainer/__init__.py · View file @ 7d1acfbd

```diff
 from .interface import BaseTrainer
-from .pytorch import PyTorchImageClassificationTrainer
+from .pytorch import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
```
nni/retiarii/trainer/pytorch/__init__.py · View file @ 7d1acfbd

```diff
-from .base import PyTorchImageClassificationTrainer
+from .base import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
 from .darts import DartsTrainer
 from .enas import EnasTrainer
 from .proxyless import ProxylessTrainer
```
nni/retiarii/trainer/pytorch/base.py · View file @ 7d1acfbd

`PyTorchImageClassificationTrainer` splits its step methods around the model call so the new multi-model trainer can reuse them, and `_train` now actually backpropagates the returned loss. After the change:

@@ -85,11 +85,13 @@ class PyTorchImageClassificationTrainer(BaseTrainer):

```python
        self._loss_fn = nn.CrossEntropyLoss()
        self._dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls),
                                                       **(dataset_kwargs or {}))
        self._optimizer = getattr(torch.optim, optimizer_cls)(model.parameters(),
                                                              **(optimizer_kwargs or {}))
        self._trainer_kwargs = trainer_kwargs or {'max_epochs': 10}
        # TODO: we will need at least two (maybe three) data loaders in future.
        self._dataloader = DataLoader(self._dataset, **(dataloader_kwargs or {}))
```

@@ -97,18 +99,32 @@ class PyTorchImageClassificationTrainer(BaseTrainer):

```python
        return correct / input.size(0)

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
        x, y = self.training_step_before_model(batch, batch_idx)
        y_hat = self.model(x)
        return self.training_step_after_model(x, y, y_hat)

    def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        x, y = batch
        if self._use_cuda:
            x, y = x.cuda(torch.device('cuda:0')), y.cuda(torch.device('cuda:0'))
        return x, y

    def training_step_after_model(self, x, y, y_hat):
        loss = self._loss_fn(y_hat, y)
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
        x, y = self.validation_step_before_model(batch, batch_idx)
        y_hat = self.model(x)
        return self.validation_step_after_model(x, y, y_hat)

    def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        x, y = batch
        if self._use_cuda:
            x, y = x.cuda(), y.cuda()
        return x, y

    def validation_step_after_model(self, x, y, y_hat):
        acc = self._accuracy(y_hat, y)
        return {'val_acc': acc}
```

@@ -126,9 +142,120 @@ class PyTorchImageClassificationTrainer(BaseTrainer):

```python
    def _train(self):
        for i, batch in enumerate(self._dataloader):
            loss = self.training_step(batch, i)
            loss.backward()

    def fit(self) -> None:
        for _ in range(self._trainer_kwargs['max_epochs']):
            self._train()
        # assuming val_acc here
        nni.report_final_result(self._validate()['val_acc'])


class PyTorchMultiModelTrainer(BaseTrainer):
    def __init__(self, multi_model, kwargs=[]):
        self.multi_model = multi_model
        self.kwargs = kwargs
        self._dataloaders = []
        self._datasets = []
        self._optimizers = []
        self._trainers = []
        self._loss_fn = nn.CrossEntropyLoss()

        self.max_steps = None
        if 'max_steps' in self.kwargs:
            self.max_steps = self.kwargs['max_steps']

        for m in self.kwargs['model_kwargs']:
            if m['use_input']:
                dataset_cls = m['dataset_cls']
                dataset_kwargs = m['dataset_kwargs']
                dataloader_kwargs = m['dataloader_kwargs']
                dataset = getattr(datasets, dataset_cls)(transform=get_default_transform(dataset_cls),
                                                         **(dataset_kwargs or {}))
                dataloader = DataLoader(dataset, **(dataloader_kwargs or {}))
                self._datasets.append(dataset)
                self._dataloaders.append(dataloader)

            if m['use_output']:
                optimizer_cls = m['optimizer_cls']
                optimizer_kwargs = m['optimizer_kwargs']
                m_header = f"M_{m['model_id']}"
                one_model_params = []
                for name, param in multi_model.named_parameters():
                    name_prefix = '_'.join(name.split('_')[:2])
                    if m_header == name_prefix:
                        one_model_params.append(param)
                optimizer = getattr(torch.optim, optimizer_cls)(one_model_params,
                                                                **(optimizer_kwargs or {}))
                self._optimizers.append(optimizer)

    def fit(self) -> None:
        torch.autograd.set_detect_anomaly(True)
        max_epochs = max([x['trainer_kwargs']['max_epochs'] for x in self.kwargs['model_kwargs']])
        for _ in range(max_epochs):
            self._train()

    def _train(self):
        for batch_idx, multi_model_batch in enumerate(zip(*self._dataloaders)):
            for opt in self._optimizers:
                opt.zero_grad()
            xs = []
            ys = []
            for idx, batch in enumerate(multi_model_batch):
                x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}')
                xs.append(x)
                ys.append(y)
            y_hats = self.multi_model(*xs)
            if len(ys) != len(xs):
                raise ValueError('len(ys) should be equal to len(xs)')
            losses = []
            report_loss = {}
            for output_idx, yhat in enumerate(y_hats):
                if len(ys) == len(y_hats):
                    loss = self.training_step_after_model(xs[output_idx], ys[output_idx], yhat)
                elif len(ys) == 1:
                    loss = self.training_step_after_model(xs[0], ys[0].to(yhat.get_device()), yhat)
                else:
                    raise ValueError('len(ys) should be either 1 or len(y_hats)')
                losses.append(loss.to("cuda:0"))
                report_loss[self.kwargs['model_kwargs'][output_idx]['model_id']] = loss.item()
            summed_loss = sum(losses)
            summed_loss.backward()
            for opt in self._optimizers:
                opt.step()
            if batch_idx % 50 == 0:
                nni.report_intermediate_result(report_loss)
            if self.max_steps and batch_idx >= self.max_steps:
                return

    def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
        x, y = self.training_step_before_model(batch, batch_idx)
        y_hat = self.model(x)
        return self.training_step_after_model(x, y, y_hat)

    def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor],
                                   batch_idx: int, device=None):
        x, y = batch
        if device:
            x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
        return x, y

    def training_step_after_model(self, x, y, y_hat):
        loss = self._loss_fn(y_hat, y)
        return loss

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
        x, y = self.validation_step_before_model(batch, batch_idx)
        y_hat = self.model(x)
        return self.validation_step_after_model(x, y, y_hat)

    def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
        x, y = batch
        if self._use_cuda:
            x, y = x.cuda(), y.cuda()
        return x, y

    def validation_step_after_model(self, x, y, y_hat):
        acc = self._accuracy(y_hat, y)
        return {'val_acc': acc}
```
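`PyTorchMultiModelTrainer` is driven entirely by `kwargs['model_kwargs']`, the per-model slot list that `LogicalPlan.assemble` builds. A sketch of the shape it consumes, with illustrative values (the keys come from the code above; the concrete values depend on the submitted models):

```python
multi_model_kwargs = {
    'is_multi_model': True,
    'model_cls': 'logical_0',      # class name of the generated merged module
    'max_steps': 100,              # optional cap on training steps per epoch
    'model_kwargs': [
        {
            'model_id': 1,
            'use_input': True,     # this model's dataloader feeds the merged forward()
            'use_output': True,    # this model's loss gets its own optimizer slot
            'dataset_cls': 'MNIST',
            'dataset_kwargs': {'root': 'data/mnist', 'download': True},
            'dataloader_kwargs': {'batch_size': 32},
            'optimizer_cls': 'SGD',
            'optimizer_kwargs': {'lr': 0.1},
            'trainer_kwargs': {'max_epochs': 1},
        },
        # one slot per original model; models whose input nodes were
        # deduplicated share the surviving slot's dataloader (use_input=False)
    ],
}
```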