OpenDAS / nni

Commit 60b2a7a3, authored Nov 05, 2020 by Yuge Zhang

    Merge branch 'dev-retiarii' of https://github.com/microsoft/nni into dev-retiarii

Parents: d6791c2b, bcb7633e

Showing 8 changed files with 121 additions and 111 deletions (+121 -111).
nni/retiarii/graph.py                          +66  -57
nni/retiarii/mutator.py                        +11   -8
nni/retiarii/operation.py                      +16  -19
nni/retiarii/operation_def/tf_op_def.py         +4   -3
nni/retiarii/operation_def/torch_op_def.py      +0   -8
test/ut/retiarii/mnist-tensorflow.json         +12   -7
test/ut/retiarii/test_graph.py                  +8   -2
test/ut/retiarii/test_mutator.py                +4   -7
nni/retiarii/graph.py

 """
-Classes related to Graph IR, except `Operation`.
+Model representation.
 """
-from __future__ import annotations
 import copy
 import json
 from enum import Enum
-from typing import *
+from typing import (Any, Dict, List, Optional, Tuple, overload)
+
+from .operation import Cell, Operation, _PseudoOperation

 __all__ = ['Model', 'ModelStatus', 'Graph', 'Node', 'Edge', 'IllegalGraphError', 'MetricData']

-MetricData = NewType('MetricData', Any)
+MetricData = Any
 """
 Graph metrics like loss, accuracy, etc.
-Maybe we can assume this is a single float number for first iteration.
+# Maybe we can assume this is a single float number for first iteration.
 """
...
@@ -36,7 +34,7 @@ class TrainingConfig:
         Trainer keyword arguments
     """

-    def __init__(self, module: str, kwargs: Dict[str, any]):
+    def __init__(self, module: str, kwargs: Dict[str, Any]):
         self.module = module
         self.kwargs = kwargs
...
@@ -44,7 +42,7 @@ class TrainingConfig:
         return f'TrainingConfig(module={self.module}, kwargs={self.kwargs})'

     @staticmethod
-    def _load(ir: Any) -> TrainingConfig:
+    def _load(ir: Any) -> 'TrainingConfig':
         return TrainingConfig(ir['module'], ir.get('kwargs', {}))

     def _dump(self) -> Any:
...
@@ -56,15 +54,14 @@ class TrainingConfig:
 class Model:
     """
-    Top-level structure of graph IR.
-    In execution engine's perspective, this is a trainable neural network model.
-    In mutator's perspective, this is a sandbox for a round of mutation.
-    Once a round of mutation starts, a sandbox is created and all mutating operations will happen inside.
-    When mutation is complete, the sandbox will be frozen to a trainable model.
-    Then the strategy will submit the model to the execution engine for training.
-    The model will record its metrics once trained.
+    Represents a neural network model.
+    During mutation, one `Model` object is created for each trainable snapshot.
+    For example, consider a mutator that inserts a node at an edge in each iteration.
+    In one iteration, the mutator invokes 4 primitives: add node, remove edge, add edge to head, add edge to tail.
+    These 4 primitives operate on one `Model` object.
+    When they are all done, the model is set to "frozen" (trainable) status and submitted to the execution engine.
+    Then a new iteration starts, and a new `Model` object is created by forking the last model.

     Attributes
     ----------
...
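A minimal sketch of the cycle this docstring describes, assuming a `mutator` already bound to a sampler and an initial `model` (`num_rounds` and `submit_for_training` are hypothetical, not part of this diff):

    # Each round forks the previous snapshot, applies the mutation primitives,
    # and yields a frozen model that can be handed to the execution engine.
    for _ in range(num_rounds):
        model = mutator.apply(model)   # forks internally; the input model is not modified
        submit_for_training(model)     # hypothetical hand-off to the execution engine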
@@ -104,17 +101,17 @@ class Model:
         self.metric: Optional[MetricData] = None
         self.intermediate_metrics: List[MetricData] = []
-        self._last_uid: int = 0
+        self._last_uid: int = 0  # FIXME: this should be global, not model-wise

     def __repr__(self):
         return f'Model(model_id={self.model_id}, status={self.status}, graphs={list(self.graphs.keys())}, ' + \
             f'training_config={self.training_config}, metric={self.metric}, intermediate_metrics={self.intermediate_metrics})'

     @property
-    def root_graph(self) -> Graph:
+    def root_graph(self) -> 'Graph':
         return self.graphs[self._root_graph_name]

-    def fork(self) -> Model:
+    def fork(self) -> 'Model':
         """
         Create a new model which has the same topology, names, and IDs as the current one.
...
@@ -136,17 +133,17 @@ class Model:
         return self._last_uid

     @staticmethod
-    def _load(ir: Any) -> Model:
+    def _load(ir: Any) -> 'Model':
         model = Model(_internal=True)
         for graph_name, graph_data in ir.items():
             if graph_name != '_training_config':
                 Graph._load(model, graph_name, graph_data)._register()
-        #model.training_config = TrainingConfig._load(ir['_training_config'])
+        model.training_config = TrainingConfig._load(ir['_training_config'])
         return model

     def _dump(self) -> Any:
         ret = {name: graph._dump() for name, graph in self.graphs.items()}
-        #ret['_training_config'] = self.training_config._dump()
+        ret['_training_config'] = self.training_config._dump()
         return ret
...
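With the training config no longer commented out, `_load` and `_dump` round-trip the full IR: a dict of graphs plus a `_training_config` entry. A sketch of the round trip, assuming `ir` was parsed from a fixture like the updated mnist-tensorflow.json below:

    model = Model._load(ir)        # restores graphs and ir['_training_config']
    assert model._dump() == ir     # holds once the IR is normalized with default values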
@@ -227,41 +224,45 @@ class Graph:
             f'output_names={self.output_names}, num_hidden_nodes={len(self.hidden_nodes)}, num_edges={len(self.edges)})'

     @property
-    def nodes(self) -> List[Node]:
+    def nodes(self) -> List['Node']:
         return [self.input_node, self.output_node] + self.hidden_nodes

     # mutation
-    def add_node(self, type: Union[Operation, str], **parameters) -> Node:
-        if isinstance(type, Operation):
-            assert not parameters
-            op = type
+    @overload
+    def add_node(self, operation: Operation) -> 'Node': ...
+    @overload
+    def add_node(self, type_name: str, parameters: Dict[str, Any] = {}) -> 'Node': ...
+
+    def add_node(self, operation_or_type, parameters={}):
+        if isinstance(operation_or_type, Operation):
+            op = operation_or_type
         else:
-            op = Operation.new(type, **parameters)
+            op = Operation.new(operation_or_type, parameters)
         return Node(self, self.model._uid(), None, op, _internal=True)._register()
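`add_node` now takes either a ready-made `Operation` or a type name plus an explicit parameters dict; the `@overload` stubs only document this for type checkers. Both call styles, sketched with illustrative names (`graph` is hypothetical):

    conv = Operation.new('Conv2D', {'filters': 32, 'kernel_size': 5})
    node1 = graph.add_node(conv)                            # pass an Operation instance
    node2 = graph.add_node('MaxPool2D', {'pool_size': 2})   # or a type name + parameters dict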
     # mutation
-    def add_edge(self, head: Tuple[Node, Optional[int]], tail: Tuple[Node, Optional[int]]) -> Edge:
+    def add_edge(self, head: Tuple['Node', Optional[int]], tail: Tuple['Node', Optional[int]]) -> 'Edge':
         assert head[0].graph is self and tail[0].graph is self
         return Edge(head, tail)._register()

-    def get_node_by_name(self, name: str) -> Optional[Node]:
+    def get_node_by_name(self, name: str) -> Optional['Node']:
         """
         Returns the node which has the specified name, or `None` if no node has this name.
         """
         found = [node for node in self.nodes if node.name == name]
         return found[0] if found else None

-    def get_nodes_by_type(self, operation_type: str) -> List[Node]:
+    def get_nodes_by_type(self, operation_type: str) -> List['Node']:
         """
         Returns nodes whose operation is of the specified type.
         """
         return [node for node in self.hidden_nodes if node.operation.type == operation_type]

-    def topo_sort(self) -> List[Node]:  # TODO
+    def topo_sort(self) -> List['Node']:  # TODO
         ...

-    def fork(self) -> Graph:
+    def fork(self) -> 'Graph':
         """
         Fork the model and return the corresponding graph in the new model.
         This shortcut might be helpful because many algorithms only care about the "stem" subgraph instead of the whole model.
...
@@ -271,7 +272,7 @@ class Graph:
     def __eq__(self, other: object) -> bool:
         return self is other

-    def _fork_to(self, model: Model) -> Graph:
+    def _fork_to(self, model: Model) -> 'Graph':
         new_graph = Graph(model, self.id, self.name, _internal=True)._register()
         new_graph.input_names = self.input_names
         new_graph.output_names = self.output_names
...

@@ -288,7 +289,7 @@ class Graph:
         return new_graph

-    def _copy(self) -> Graph:
+    def _copy(self) -> 'Graph':
         # Copy this graph inside the model.
         # The new graph will have identical topology, but its nodes' names and IDs will be different.
         new_graph = Graph(self.model, self.model._uid(), _internal=True)._register()
...
@@ -308,12 +309,12 @@ class Graph:
         return new_graph

-    def _register(self) -> Graph:
+    def _register(self) -> 'Graph':
         self.model.graphs[self.name] = self
         return self

     @staticmethod
-    def _load(model: Model, name: str, ir: Any) -> Graph:
+    def _load(model: Model, name: str, ir: Any) -> 'Graph':
         graph = Graph(model, model._uid(), name, _internal=True)
         graph.input_names = ir.get('inputs')
         graph.output_names = ir.get('outputs')
...
@@ -381,19 +382,19 @@ class Node:
         return f'Node(id={self.id}, name={self.name}, operation={self.operation})'

     @property
-    def predecessors(self) -> List[Node]:
+    def predecessors(self) -> List['Node']:
         return sorted(set(edge.head for edge in self.incoming_edges), key=(lambda node: node.id))

     @property
-    def successors(self) -> List[Node]:
+    def successors(self) -> List['Node']:
         return sorted(set(edge.tail for edge in self.outgoing_edges), key=(lambda node: node.id))

     @property
-    def incoming_edges(self) -> List[Edge]:
+    def incoming_edges(self) -> List['Edge']:
         return [edge for edge in self.graph.edges if edge.tail is self]

     @property
-    def outgoing_edges(self) -> List[Edge]:
+    def outgoing_edges(self) -> List['Edge']:
         return [edge for edge in self.graph.edges if edge.head is self]

     @property
...
@@ -403,12 +404,16 @@ class Node:
     # mutation
-    def update_operation(self, type: Union[Operation, str], **parameters) -> None:
-        if isinstance(type, Operation):
-            assert not parameters
-            self.operation = type
+    @overload
+    def update_operation(self, operation: Operation) -> None: ...
+    @overload
+    def update_operation(self, type_name: str, parameters: Dict[str, Any] = {}) -> None: ...
+
+    def update_operation(self, operation_or_type, parameters={}):
+        if isinstance(operation_or_type, Operation):
+            self.operation = operation_or_type
         else:
-            self.operation = Operation.new(type, **parameters)
+            self.operation = Operation.new(operation_or_type, parameters)

     # mutation
     def remove(self) -> None:
...
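`update_operation` mirrors the `add_node` change: pass an `Operation` instance, or a type name with an explicit parameters dict (node name illustrative, taken from the tests below):

    pool1.update_operation('AveragePooling2D', {'pool_size': 2})
    # equivalently, with a pre-built Operation:
    pool1.update_operation(Operation.new('AveragePooling2D', {'pool_size': 2}))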
@@ -422,26 +427,29 @@ class Node:
         Duplicate the cell template and let this node reference the newly created copy.
         """
         new_cell = self.cell._copy()._register()
-        self.operation = Operation.new('_cell', cell=new_cell.name)
+        self.operation = Cell(new_cell.name)
         return new_cell

     def __eq__(self, other: object) -> bool:
         return self is other

-    def _register(self) -> Node:
+    def _register(self) -> 'Node':
         self.graph.hidden_nodes.append(self)
         return self

     @staticmethod
-    def _load(graph: Graph, name: str, ir: Any) -> Node:
-        ir = dict(ir)
-        if 'type' not in ir and 'cell' in ir:
-            ir['type'] = '_cell'
-        op = Operation.new(**ir)
+    def _load(graph: Graph, name: str, ir: Any) -> 'Node':
+        if ir['type'] == '_cell':
+            op = Cell(ir['cell'], ir.get('parameters', {}))
+        else:
+            op = Operation.new(ir['type'], ir.get('parameters', {}))
         return Node(graph, graph.model._uid(), name, op)

     def _dump(self) -> Any:
-        return {'type': self.operation.type, **self.operation.parameters}
+        ret = {'type': self.operation.type, 'parameters': self.operation.parameters}
+        if isinstance(self.operation, Cell):
+            ret['cell'] = self.operation.cell_name
+        return ret


 class Edge:
...
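In the reworked `_dump`, operation parameters live under a dedicated 'parameters' key instead of being splatted into the node dict, and cell nodes additionally record their cell name; the updated mnist-tensorflow.json fixture below follows the same shape. Illustrative node IR values:

    {'type': 'MaxPool2D', 'parameters': {'pool_size': 2}}   # a plain operation node
    {'type': '_cell', 'parameters': {}, 'cell': 'stem'}     # a cell node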
@@ -499,14 +507,15 @@ class Edge:
     def remove(self) -> None:
         self.graph.edges.remove(self)

-    def _register(self) -> Edge:
+    def _register(self) -> 'Edge':
         self.graph.edges.append(self)
         return self

     @staticmethod
-    def _load(graph: Graph, ir: Any) -> Edge:
+    def _load(graph: Graph, ir: Any) -> 'Edge':
         head = graph.get_node_by_name(ir['head'][0])
         tail = graph.get_node_by_name(ir['tail'][0])
         assert head is not None and tail is not None
         return Edge((head, ir['head'][1]), (tail, ir['tail'][1]), _internal=True)

     def _dump(self) -> Any:
...
nni/retiarii/mutator.py

-from __future__ import annotations
-from typing import *
-from .graph import *
+from typing import (Any, Iterable, List, Optional)
+
+from .graph import Model

 __all__ = ['Sampler', 'Mutator']

-Choice = NewType('Choice', Any)
+Choice = Any


 class Sampler:
     """
     Handles `Mutator.choice()` calls.
     """
-    def choice(self, candidates: List[Choice], mutator: Mutator, model: Model, index: int) -> Choice:
+    def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice:
         raise NotImplementedError()

-    def mutation_start(self, mutator: Mutator, model: Model) -> None:
+    def mutation_start(self, mutator: 'Mutator', model: Model) -> None:
         pass

-    def mutation_end(self, mutator: Mutator, model: Model) -> None:
+    def mutation_end(self, mutator: 'Mutator', model: Model) -> None:
         pass
...
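Only `choice` must be overridden in a concrete `Sampler`; the `mutation_start`/`mutation_end` hooks default to no-ops. A minimal sketch (a uniform random sampler is an assumption, not part of this diff):

    import random

    class RandomSampler(Sampler):
        def choice(self, candidates, mutator, model, index):
            return random.choice(candidates)   # pick one candidate uniformly at random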
@@ -44,11 +44,12 @@ class Mutator:
         self._cur_model: Optional[Model] = None
         self._cur_choice_idx: Optional[int] = None

-    def bind_sampler(self, sampler: Sampler) -> Mutator:
+    def bind_sampler(self, sampler: Sampler) -> 'Mutator':
         """
         Set the sampler which will handle `Mutator.choice` calls.
         """
         self.sampler = sampler
         return self

     def apply(self, model: Model) -> Model:
         """
...

@@ -57,6 +58,7 @@ class Mutator:
         The model will be copied before mutation and the original model will not be modified.
         """
+        assert self.sampler is not None
         copy = model.fork()
         self._cur_model = copy
         self._cur_choice_idx = 0
...

@@ -93,6 +95,7 @@ class Mutator:
         """
         Ask sampler to make a choice.
         """
+        assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None
         ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx)
         self._cur_choice_idx += 1
         return ret
...
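The two new assertions encode the calling contract: bind a sampler before `apply`, and expect a fresh fork back rather than in-place mutation. A hedged usage sketch (`RandomSampler` from the sketch above; `DebugMutator` appears in the tests below):

    mutator = DebugMutator()
    mutator.bind_sampler(RandomSampler())
    new_model = mutator.apply(model)   # the input model is left unmodified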
nni/retiarii/operation.py

-from __future__ import annotations
-from enum import Enum
-from typing import *
+from typing import (Any, Dict)
+
+from . import debug_configs

 __all__ = ['Operation', 'Cell']


 class Operation:
     """
...
@@ -24,13 +24,9 @@ class Operation:
         Arbitrary key-value parameters (e.g. kernel_size).
     """

-    def __init__(self, type: str, parameters: Dict[str, Any], _internal_access: bool = False):
-        assert _internal_access, '`Operation()` is private, use `Operation.new()` instead'
-        self.type: str = type
+    def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False):
+        assert _internal, '`Operation()` is private, use `Operation.new()` instead'
+        self.type: str = type_name
         self.parameters: Dict[str, Any] = parameters

     def to_init_code(self, field: str) -> str:
...
@@ -47,19 +43,19 @@ class Operation:
         return True

     @staticmethod
-    def new(type: str, **parameters: Any) -> Operation:
-        if type == '_cell':
+    def new(type_name: str, parameters: Dict[str, Any] = {}) -> 'Operation':
+        if type_name == '_cell':
             return Cell(parameters['cell'])
         else:
             if debug_configs.framework.lower() in ('torch', 'pytorch'):
-                from .operation_def import torch_op_def
+                from .operation_def import torch_op_def  # pylint: disable=unused-import
                 cls = PyTorchOperation._find_subclass(type)
             elif debug_configs.framework.lower() in ('tf', 'tensorflow'):
-                from .operation_def import tf_op_def
+                from .operation_def import tf_op_def  # pylint: disable=unused-import
                 cls = TensorFlowOperation._find_subclass(type)
             else:
                 raise ValueError(f'Unsupported framework: {debug_configs.framework}')
-        return cls(type, parameters, _internal_access=True)
+        return cls(type_name, parameters, _internal=True)

     @classmethod
     def _find_subclass(cls, subclass_name):
...
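`Operation.new` now takes an explicit parameters dict instead of `**kwargs` and dispatches on `debug_configs.framework`; '_cell' short-circuits to `Cell`. A sketch matching the TensorFlow setting used in the tests below:

    import nni.retiarii.debug_configs
    nni.retiarii.debug_configs.framework = 'tensorflow'

    pool = Operation.new('MaxPool2D', {'pool_size': 2})   # resolved to a TensorFlowOperation subclass
    stem = Operation.new('_cell', {'cell': 'stem'})       # returns a Cell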
@@ -120,12 +116,13 @@ class Cell(Operation):
     framework
         No real usage. Exists for compatibility with base class.
     """

-    def __init__(self, cell_name: str):
+    def __init__(self, cell_name: str, parameters: Dict[str, Any] = {}):
         self.type = '_cell'
-        self.parameters = {'cell': cell_name}
+        self.cell_name = cell_name
+        self.parameters = parameters

-    def to_init_code(self, field: str) -> str:
-        return f'self.{field} = {self.parameters["cell"]}()'
+    def _to_class_name(self):
+        return self.cell_name


 class _PseudoOperation(Operation):
...
nni/retiarii/operation_def/tf_op_def.py

 from ..operation import TensorFlowOperation


 class Conv2D(TensorFlowOperation):
-    def to_init_code(self, field):
-        parameters = {'padding': 'same', **parameters}
-        super().__init__(type, parameters, _internal_access)
+    def __init__(self, type_name, parameters, _internal):
+        if 'padding' not in parameters:
+            parameters['padding'] = 'same'
+        super().__init__(type_name, parameters, _internal)
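The rewritten `Conv2D.__init__` still defaults padding to 'same' while letting an explicit value win; the change mainly aligns the constructor with the renamed base-class signature (`type_name`, `_internal`). A sketch, assuming the TensorFlow framework is selected as above (parameter values illustrative):

    conv_a = Operation.new('Conv2D', {'filters': 32, 'kernel_size': 5})
    assert conv_a.parameters['padding'] == 'same'      # default filled in

    conv_b = Operation.new('Conv2D', {'filters': 32, 'kernel_size': 5, 'padding': 'valid'})
    assert conv_b.parameters['padding'] == 'valid'     # explicit value preserved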
nni/retiarii/operation_def/torch_op_def.py

 from ..operation import PyTorchOperation

-class relu(PyTorchOperation):
-    def to_init_code(self, field):
-        return ''
-
-    def to_forward_code(self, field, output, *inputs) -> str:
-        assert len(inputs) == 1
-        return f'{output} = nn.functional.relu({inputs[0]})'

 class Flatten(PyTorchOperation):
     def to_init_code(self, field):
...
test/ut/retiarii/mnist-tensorflow.json

...
@@ -4,10 +4,10 @@
     "outputs": ["metric"],
     "nodes": {
-        "stem": {"cell": "stem"},
+        "stem": {"type": "_cell", "cell": "stem"},
         "flatten": {"type": "Flatten"},
-        "fc1": {"type": "Dense", "units": 1024, "activation": "relu"},
-        "fc2": {"type": "Dense", "units": 10},
+        "fc1": {"type": "Dense", "parameters": {"units": 1024, "activation": "relu"}},
+        "fc2": {"type": "Dense", "parameters": {"units": 10}},
         "softmax": {"type": "Softmax"}
     },
...
@@ -23,10 +23,10 @@
     "stem": {
         "nodes": {
-            "conv1": {"type": "Conv2D", "filters": 32, "kernel_size": 5, "activation": "relu"},
-            "pool1": {"type": "MaxPool2D", "pool_size": 2},
-            "conv2": {"type": "Conv2D", "filters": 64, "kernel_size": 5, "activation": "relu"},
-            "pool2": {"type": "MaxPool2D", "pool_size": 2}
+            "conv1": {"type": "Conv2D", "parameters": {"filters": 32, "kernel_size": 5, "activation": "relu"}},
+            "pool1": {"type": "MaxPool2D", "parameters": {"pool_size": 2}},
+            "conv2": {"type": "Conv2D", "parameters": {"filters": 64, "kernel_size": 5, "activation": "relu"}},
+            "pool2": {"type": "MaxPool2D", "parameters": {"pool_size": 2}}
         },
         "edges": [
...

@@ -36,5 +36,10 @@
             {"head": ["conv2", null], "tail": ["pool2", null]},
             {"head": ["pool2", null], "tail": ["_outputs", 0]}
         ]
     },
+    "_training_config": {
+        "module": "_debug_no_trainer",
+        "kwargs": {}
+    }
 }
test/ut/retiarii/test_graph.py

...
@@ -23,13 +23,19 @@ def _test_file(json_path):
     # add default values to JSON, so we can compare with `==`
     for graph_name, graph in orig_ir.items():
+        if graph_name == '_training_config':
+            continue
         if 'inputs' not in graph:
             graph['inputs'] = None
         if 'outputs' not in graph:
             graph['outputs'] = None
+        for node_name, node in graph['nodes'].items():
+            if 'type' not in node and 'cell' in node:
+                node['type'] = '_cell'
+            if 'parameters' not in node:
+                node['parameters'] = {}

     # debug output
     #json.dump(orig_ir, open('_orig.json', 'w'), indent=4)
     #json.dump(dump_ir, open('_dump.json', 'w'), indent=4)

     assert orig_ir == dump_ir
...
test/ut/retiarii/test_mutator.py

...
@@ -8,6 +8,10 @@ from nni.retiarii import *
 import nni.retiarii.debug_configs
 nni.retiarii.debug_configs.framework = 'tensorflow'

+max_pool = Operation.new('MaxPool2D', {'pool_size': 2})
+avg_pool = Operation.new('AveragePooling2D', {'pool_size': 2})
+global_pool = Operation.new('GlobalAveragePooling2D')


 class DebugSampler(Sampler):
     def __init__(self):
...

@@ -22,9 +26,6 @@ class DebugSampler(Sampler):
 class DebugMutator(Mutator):
     def mutate(self, model):
-        max_pool = Operation.new('MaxPool2D', pool_size=2)
-        avg_pool = Operation.new('AveragePooling2D', pool_size=2)
-        global_pool = Operation.new('GlobalAveragePooling2D')
         ops = [max_pool, avg_pool, global_pool]

         pool1 = model.graphs['stem'].get_node_by_name('pool1')
...

@@ -67,10 +68,6 @@ def _get_pools(model):
     return pool1, pool2

-max_pool = Operation.new(type='MaxPool2D', pool_size=2)
-avg_pool = Operation.new(type='AveragePooling2D', pool_size=2)
-global_pool = Operation.new(type='GlobalAveragePooling2D')

 if __name__ == '__main__':
     test_dry_run()
     test_mutation()