Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
b3cdee85
Unverified
Commit
b3cdee85
authored
Dec 14, 2020
by
liuzhe-lz
Committed by
GitHub
Dec 14, 2020
Browse files
fix uid duplicate and add type hint alias for edge endpoint (#3188)
parent
192a807b
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
25 additions
and
26 deletions
+25
-26
nni/retiarii/graph.py
nni/retiarii/graph.py
+17
-24
nni/retiarii/utils.py
nni/retiarii/utils.py
+8
-2
No files found.
nni/retiarii/graph.py
View file @
b3cdee85
...
@@ -9,15 +9,19 @@ from collections import defaultdict
...
@@ -9,15 +9,19 @@ from collections import defaultdict
from
typing
import
(
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
,
overload
)
from
typing
import
(
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Union
,
overload
)
from
.operation
import
Cell
,
Operation
,
_IOPseudoOperation
from
.operation
import
Cell
,
Operation
,
_IOPseudoOperation
from
.utils
import
uid
__all__
=
[
'Model'
,
'ModelStatus'
,
'Graph'
,
'Node'
,
'Edge'
,
'IllegalGraphError'
,
'MetricData'
]
__all__
=
[
'Model'
,
'ModelStatus'
,
'Graph'
,
'Node'
,
'Edge'
,
'IllegalGraphError'
,
'MetricData'
]
MetricData
=
Any
MetricData
=
Any
"""
"""
Graph metrics like loss, accuracy, etc.
Type hint for graph metrics (loss, accuracy, etc).
"""
# Maybe we can assume this is a single float number for first iteration.
EdgeEndpoint
=
Tuple
[
'Node'
,
Optional
[
int
]]
"""
Type hint for edge's endpoint. The int indicates nodes' order.
"""
"""
...
@@ -88,12 +92,10 @@ class Model:
...
@@ -88,12 +92,10 @@ class Model:
intermediate_metrics
intermediate_metrics
Intermediate training metrics. If the model is not trained, it's an empty list.
Intermediate training metrics. If the model is not trained, it's an empty list.
"""
"""
_cur_model_id
=
0
def
__init__
(
self
,
_internal
=
False
):
def
__init__
(
self
,
_internal
=
False
):
assert
_internal
,
'`Model()` is private, use `model.fork()` instead'
assert
_internal
,
'`Model()` is private, use `model.fork()` instead'
Model
.
_cur_model_id
+=
1
self
.
model_id
:
int
=
uid
(
'model'
)
self
.
model_id
=
Model
.
_cur_model_id
self
.
status
:
ModelStatus
=
ModelStatus
.
Mutating
self
.
status
:
ModelStatus
=
ModelStatus
.
Mutating
...
@@ -106,8 +108,6 @@ class Model:
...
@@ -106,8 +108,6 @@ class Model:
self
.
metric
:
Optional
[
MetricData
]
=
None
self
.
metric
:
Optional
[
MetricData
]
=
None
self
.
intermediate_metrics
:
List
[
MetricData
]
=
[]
self
.
intermediate_metrics
:
List
[
MetricData
]
=
[]
self
.
_last_uid
:
int
=
0
# FIXME: this should be global, not model-wise
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
'Model(model_id=
{
self
.
model_id
}
, status=
{
self
.
status
}
, graphs=
{
list
(
self
.
graphs
.
keys
())
}
, '
+
\
return
f
'Model(model_id=
{
self
.
model_id
}
, status=
{
self
.
status
}
, graphs=
{
list
(
self
.
graphs
.
keys
())
}
, '
+
\
f
'training_config=
{
self
.
training_config
}
, metric=
{
self
.
metric
}
, intermediate_metrics=
{
self
.
intermediate_metrics
}
)'
f
'training_config=
{
self
.
training_config
}
, metric=
{
self
.
metric
}
, intermediate_metrics=
{
self
.
intermediate_metrics
}
)'
...
@@ -130,13 +130,8 @@ class Model:
...
@@ -130,13 +130,8 @@ class Model:
new_model
.
graphs
=
{
name
:
graph
.
_fork_to
(
new_model
)
for
name
,
graph
in
self
.
graphs
.
items
()}
new_model
.
graphs
=
{
name
:
graph
.
_fork_to
(
new_model
)
for
name
,
graph
in
self
.
graphs
.
items
()}
new_model
.
training_config
=
copy
.
deepcopy
(
self
.
training_config
)
new_model
.
training_config
=
copy
.
deepcopy
(
self
.
training_config
)
new_model
.
history
=
self
.
history
+
[
self
]
new_model
.
history
=
self
.
history
+
[
self
]
new_model
.
_last_uid
=
self
.
_last_uid
return
new_model
return
new_model
def
_uid
(
self
)
->
int
:
self
.
_last_uid
+=
1
return
self
.
_last_uid
@
staticmethod
@
staticmethod
def
_load
(
ir
:
Any
)
->
'Model'
:
def
_load
(
ir
:
Any
)
->
'Model'
:
model
=
Model
(
_internal
=
True
)
model
=
Model
(
_internal
=
True
)
...
@@ -295,7 +290,7 @@ class Graph:
...
@@ -295,7 +290,7 @@ class Graph:
op
=
operation_or_type
op
=
operation_or_type
else
:
else
:
op
=
Operation
.
new
(
operation_or_type
,
parameters
,
name
)
op
=
Operation
.
new
(
operation_or_type
,
parameters
,
name
)
return
Node
(
self
,
self
.
model
.
_
uid
(),
name
,
op
,
_internal
=
True
).
_register
()
return
Node
(
self
,
uid
(),
name
,
op
,
_internal
=
True
).
_register
()
@
overload
@
overload
def
insert_node_on_edge
(
self
,
edge
:
'Edge'
,
name
:
str
,
operation
:
Operation
)
->
'Node'
:
...
def
insert_node_on_edge
(
self
,
edge
:
'Edge'
,
name
:
str
,
operation
:
Operation
)
->
'Node'
:
...
...
@@ -307,7 +302,7 @@ class Graph:
...
@@ -307,7 +302,7 @@ class Graph:
op
=
operation_or_type
op
=
operation_or_type
else
:
else
:
op
=
Operation
.
new
(
operation_or_type
,
parameters
,
name
)
op
=
Operation
.
new
(
operation_or_type
,
parameters
,
name
)
new_node
=
Node
(
self
,
self
.
model
.
_
uid
(),
name
,
op
,
_internal
=
True
).
_register
()
new_node
=
Node
(
self
,
uid
(),
name
,
op
,
_internal
=
True
).
_register
()
# update edges
# update edges
self
.
add_edge
((
edge
.
head
,
edge
.
head_slot
),
(
new_node
,
None
))
self
.
add_edge
((
edge
.
head
,
edge
.
head_slot
),
(
new_node
,
None
))
self
.
add_edge
((
new_node
,
None
),
(
edge
.
tail
,
edge
.
tail_slot
))
self
.
add_edge
((
new_node
,
None
),
(
edge
.
tail
,
edge
.
tail_slot
))
...
@@ -315,7 +310,7 @@ class Graph:
...
@@ -315,7 +310,7 @@ class Graph:
return
new_node
return
new_node
# mutation
# mutation
def
add_edge
(
self
,
head
:
Tuple
[
'Node'
,
Optional
[
int
]]
,
tail
:
Tuple
[
'Node'
,
Optional
[
int
]]
)
->
'Edge'
:
def
add_edge
(
self
,
head
:
EdgeEndpo
int
,
tail
:
EdgeEndpo
int
)
->
'Edge'
:
assert
head
[
0
].
graph
is
self
and
tail
[
0
].
graph
is
self
assert
head
[
0
].
graph
is
self
and
tail
[
0
].
graph
is
self
return
Edge
(
head
,
tail
,
_internal
=
True
).
_register
()
return
Edge
(
head
,
tail
,
_internal
=
True
).
_register
()
...
@@ -414,7 +409,7 @@ class Graph:
...
@@ -414,7 +409,7 @@ class Graph:
def
_copy
(
self
)
->
'Graph'
:
def
_copy
(
self
)
->
'Graph'
:
# Copy this graph inside the model.
# Copy this graph inside the model.
# The new graph will have identical topology, but its nodes' name and ID will be different.
# The new graph will have identical topology, but its nodes' name and ID will be different.
new_graph
=
Graph
(
self
.
model
,
self
.
model
.
_
uid
(),
_internal
=
True
).
_register
()
new_graph
=
Graph
(
self
.
model
,
uid
(),
_internal
=
True
).
_register
()
new_graph
.
input_node
.
operation
.
io_names
=
self
.
input_node
.
operation
.
io_names
new_graph
.
input_node
.
operation
.
io_names
=
self
.
input_node
.
operation
.
io_names
new_graph
.
output_node
.
operation
.
io_names
=
self
.
output_node
.
operation
.
io_names
new_graph
.
output_node
.
operation
.
io_names
=
self
.
output_node
.
operation
.
io_names
new_graph
.
input_node
.
update_label
(
self
.
input_node
.
label
)
new_graph
.
input_node
.
update_label
(
self
.
input_node
.
label
)
...
@@ -423,7 +418,7 @@ class Graph:
...
@@ -423,7 +418,7 @@ class Graph:
id_to_new_node
=
{}
# old node ID -> new node object
id_to_new_node
=
{}
# old node ID -> new node object
for
old_node
in
self
.
hidden_nodes
:
for
old_node
in
self
.
hidden_nodes
:
new_node
=
Node
(
new_graph
,
self
.
model
.
_
uid
(),
None
,
old_node
.
operation
,
_internal
=
True
).
_register
()
new_node
=
Node
(
new_graph
,
uid
(),
None
,
old_node
.
operation
,
_internal
=
True
).
_register
()
new_node
.
update_label
(
old_node
.
label
)
new_node
.
update_label
(
old_node
.
label
)
id_to_new_node
[
old_node
.
id
]
=
new_node
id_to_new_node
[
old_node
.
id
]
=
new_node
...
@@ -440,7 +435,7 @@ class Graph:
...
@@ -440,7 +435,7 @@ class Graph:
@
staticmethod
@
staticmethod
def
_load
(
model
:
Model
,
name
:
str
,
ir
:
Any
)
->
'Graph'
:
def
_load
(
model
:
Model
,
name
:
str
,
ir
:
Any
)
->
'Graph'
:
graph
=
Graph
(
model
,
model
.
_
uid
(),
name
,
_internal
=
True
)
graph
=
Graph
(
model
,
uid
(),
name
,
_internal
=
True
)
graph
.
input_node
.
operation
.
io_names
=
ir
.
get
(
'inputs'
)
graph
.
input_node
.
operation
.
io_names
=
ir
.
get
(
'inputs'
)
graph
.
output_node
.
operation
.
io_names
=
ir
.
get
(
'outputs'
)
graph
.
output_node
.
operation
.
io_names
=
ir
.
get
(
'outputs'
)
for
node_name
,
node_data
in
ir
[
'nodes'
].
items
():
for
node_name
,
node_data
in
ir
[
'nodes'
].
items
():
...
@@ -501,6 +496,8 @@ class Node:
...
@@ -501,6 +496,8 @@ class Node:
self
.
graph
:
Graph
=
graph
self
.
graph
:
Graph
=
graph
self
.
id
:
int
=
node_id
self
.
id
:
int
=
node_id
self
.
name
:
str
=
name
or
f
'_generated_
{
node_id
}
'
self
.
name
:
str
=
name
or
f
'_generated_
{
node_id
}
'
# TODO: the operation is likely to be considered editable by end-user and it will be hard to debug
# maybe we should copy it here or make Operation class immutable, in next release
self
.
operation
:
Operation
=
operation
self
.
operation
:
Operation
=
operation
self
.
label
:
str
=
None
self
.
label
:
str
=
None
...
@@ -577,7 +574,7 @@ class Node:
...
@@ -577,7 +574,7 @@ class Node:
op
=
Cell
(
ir
[
'operation'
][
'cell_name'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}))
op
=
Cell
(
ir
[
'operation'
][
'cell_name'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}))
else
:
else
:
op
=
Operation
.
new
(
ir
[
'operation'
][
'type'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}))
op
=
Operation
.
new
(
ir
[
'operation'
][
'type'
],
ir
[
'operation'
].
get
(
'parameters'
,
{}))
node
=
Node
(
graph
,
graph
.
model
.
_
uid
(),
name
,
op
)
node
=
Node
(
graph
,
uid
(),
name
,
op
)
if
'label'
in
ir
:
if
'label'
in
ir
:
node
.
update_label
(
ir
[
'label'
])
node
.
update_label
(
ir
[
'label'
])
return
node
return
node
...
@@ -626,11 +623,7 @@ class Edge:
...
@@ -626,11 +623,7 @@ class Edge:
If the node does not care about order, this can be `-1`.
If the node does not care about order, this can be `-1`.
"""
"""
def
__init__
(
def
__init__
(
self
,
head
:
EdgeEndpoint
,
tail
:
EdgeEndpoint
,
_internal
:
bool
=
False
):
self
,
head
:
Tuple
[
Node
,
Optional
[
int
]],
tail
:
Tuple
[
Node
,
Optional
[
int
]],
_internal
:
bool
=
False
):
assert
_internal
,
'`Edge()` is private'
assert
_internal
,
'`Edge()` is private'
self
.
graph
:
Graph
=
head
[
0
].
graph
self
.
graph
:
Graph
=
head
[
0
].
graph
self
.
head
:
Node
=
head
[
0
]
self
.
head
:
Node
=
head
[
0
]
...
...
nni/retiarii/utils.py
View file @
b3cdee85
from
collections
import
defaultdict
import
inspect
import
inspect
def
import_
(
target
:
str
,
allow_none
:
bool
=
False
)
->
'Any'
:
def
import_
(
target
:
str
,
allow_none
:
bool
=
False
)
->
'Any'
:
...
@@ -7,7 +8,6 @@ def import_(target: str, allow_none: bool = False) -> 'Any':
...
@@ -7,7 +8,6 @@ def import_(target: str, allow_none: bool = False) -> 'Any':
module
=
__import__
(
path
,
globals
(),
locals
(),
[
identifier
])
module
=
__import__
(
path
,
globals
(),
locals
(),
[
identifier
])
return
getattr
(
module
,
identifier
)
return
getattr
(
module
,
identifier
)
_records
=
{}
_records
=
{}
def
get_records
():
def
get_records
():
...
@@ -83,3 +83,9 @@ def register_trainer():
...
@@ -83,3 +83,9 @@ def register_trainer():
return
m
return
m
return
_register
return
_register
_last_uid
=
defaultdict
(
int
)
def
uid
(
namespace
:
str
=
'default'
)
->
int
:
_last_uid
[
namespace
]
+=
1
return
_last_uid
[
namespace
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment