OpenDAS/nni · Commit c837bfc0 (unverified)

Fix deepcopy of mutables (#4400)

Authored Dec 21, 2021 by Yuge Zhang, committed by GitHub on Dec 21, 2021
Parent: 72087f8a

Showing 5 changed files with 137 additions and 63 deletions (+137, -63)
nni/retiarii/nn/pytorch/api.py           +19  -27
nni/retiarii/nn/pytorch/component.py     +19  -14
nni/retiarii/nn/pytorch/nasbench101.py   +17  -20
nni/retiarii/nn/pytorch/utils.py         +34   -2
test/ut/retiarii/test_highlevel_apis.py  +48   -0
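
Background (a minimal standalone sketch, not code from this repository): copy.deepcopy rebuilds an instance by calling cls.__new__(cls) with no arguments and restoring attributes through __dict__ afterwards, so any mutable whose __new__ required constructor arguments, as LayerChoice and friends did before this commit, failed under deepcopy. The hypothetical OldStyleChoice class below only reproduces that failure mode.

import copy

class OldStyleChoice:
    # Mimics the pre-fix pattern: __new__ insists on constructor arguments.
    def __new__(cls, candidates, *, label=None):
        return super().__new__(cls)

    def __init__(self, candidates, *, label=None):
        self.candidates = candidates
        self.label = label

choice = OldStyleChoice([1, 2, 3], label='lc')
try:
    copy.deepcopy(choice)
except TypeError as err:
    # deepcopy calls OldStyleChoice.__new__(OldStyleChoice) with no further
    # arguments, so the required positional argument `candidates` is missing.
    print('deepcopy failed:', err)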
nni/retiarii/nn/pytorch/api.py

@@ -10,14 +10,13 @@ import torch.nn as nn
 from nni.common.serializer import Translatable
 from nni.retiarii.serializer import basic_unit
-from nni.retiarii.utils import NoContextError
-from .utils import generate_new_label, get_fixed_value
+from .utils import Mutable, generate_new_label, get_fixed_value

 __all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']


-class LayerChoice(nn.Module):
+class LayerChoice(Mutable):
     """
     Layer choice selects one of the ``candidates``, then apply it on inputs and return results.
@@ -60,16 +59,14 @@ class LayerChoice(nn.Module):
     # FIXME: prior is designed but not supported yet
-    def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
-                prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
-        try:
-            chosen = get_fixed_value(label)
-            if isinstance(candidates, list):
-                return candidates[int(chosen)]
-            else:
-                return candidates[chosen]
-        except NoContextError:
-            return super().__new__(cls)
+    @classmethod
+    def create_fixed_module(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
+                            label: Optional[str] = None, **kwargs):
+        chosen = get_fixed_value(label)
+        if isinstance(candidates, list):
+            return candidates[int(chosen)]
+        else:
+            return candidates[chosen]

     def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
                  prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
@@ -159,7 +156,7 @@ class LayerChoice(nn.Module):
         return f'LayerChoice({self.candidates}, label={repr(self.label)})'


-class InputChoice(nn.Module):
+class InputChoice(Mutable):
     """
     Input choice selects ``n_chosen`` inputs from ``choose_from`` (contains ``n_candidates`` keys).
     Use ``reduction`` to specify how chosen inputs are reduced into one output. A few options are:
@@ -185,13 +182,10 @@ class InputChoice(nn.Module):
         Identifier of the input choice.
     """

-    def __new__(cls, n_candidates: int, n_chosen: Optional[int] = 1,
-                reduction: str = 'sum', *,
-                prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
-        try:
-            return ChosenInputs(get_fixed_value(label), reduction=reduction)
-        except NoContextError:
-            return super().__new__(cls)
+    @classmethod
+    def create_fixed_module(cls, n_candidates: int, n_chosen: Optional[int] = 1,
+                            reduction: str = 'sum', *,
+                            prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
+        return ChosenInputs(get_fixed_value(label), reduction=reduction)

     def __init__(self, n_candidates: int, n_chosen: Optional[int] = 1,
                  reduction: str = 'sum', *,
@@ -234,7 +228,7 @@ class InputChoice(nn.Module):
             f'reduction={repr(self.reduction)}, label={repr(self.label)})'


-class ValueChoice(Translatable, nn.Module):
+class ValueChoice(Translatable, Mutable):
     """
     ValueChoice is to choose one from ``candidates``.
@@ -302,11 +296,9 @@ class ValueChoice(Translatable, nn.Module):
     # FIXME: prior is designed but not supported yet
-    def __new__(cls, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
-        try:
-            return get_fixed_value(label)
-        except NoContextError:
-            return super().__new__(cls)
+    @classmethod
+    def create_fixed_module(cls, candidates: List[Any], *, label: Optional[str] = None, **kwargs):
+        return get_fixed_value(label)

     def __init__(self, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
         super().__init__()
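
The refactor above keeps the candidate-lookup behaviour of the old __new__: with a fixed architecture, a list of candidates is indexed by the integer position of the chosen candidate, while a dict is indexed by its key. A small standalone illustration (the helper name resolve_layer_choice and the sample candidates are hypothetical, used only to show the branch):

import torch.nn as nn

def resolve_layer_choice(candidates, chosen):
    # Mirrors the branch in LayerChoice.create_fixed_module above:
    # list candidates are indexed by position, dict candidates by key.
    if isinstance(candidates, list):
        return candidates[int(chosen)]
    else:
        return candidates[chosen]

print(resolve_layer_choice([nn.ReLU(), nn.Sigmoid()], '1'))                   # the Sigmoid instance
print(resolve_layer_choice({'relu': nn.ReLU(), 'gelu': nn.GELU()}, 'gelu'))   # the GELU instance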
nni/retiarii/nn/pytorch/component.py

@@ -9,14 +9,13 @@ from .api import LayerChoice, InputChoice
 from .nn import ModuleList
 from .nasbench101 import NasBench101Cell, NasBench101Mutator
-from .utils import generate_new_label, get_fixed_value
-from ...utils import NoContextError
+from .utils import Mutable, generate_new_label, get_fixed_value

 __all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']


-class Repeat(nn.Module):
+class Repeat(Mutable):
     """
     Repeat a block by a variable number of times.
@@ -25,23 +24,29 @@ class Repeat(nn.Module):
     blocks : function, list of function, module or list of module
         The block to be repeated. If not a list, it will be replicated into a list.
         If a list, it should be of length ``max_depth``, the modules will be instantiated in order and a prefix will be taken.
-        If a function, it will be called to instantiate a module. Otherwise the module will be deep-copied.
+        If a function, it will be called (the argument is the index) to instantiate a module.
+        Otherwise the module will be deep-copied.
     depth : int or tuple of int
         If one number, the block will be repeated by a fixed number of times. If a tuple, it should be (min, max),
         meaning that the block will be repeated at least `min` times and at most `max` times.
     """

-    def __new__(cls, blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]],
-                depth: Union[int, Tuple[int, int]], label: Optional[str] = None):
-        try:
-            repeat = get_fixed_value(label)
-            return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
-        except NoContextError:
-            return super().__new__(cls)
+    @classmethod
+    def create_fixed_module(cls,
+                            blocks: Union[Callable[[int], nn.Module],
+                                          List[Callable[[int], nn.Module]],
+                                          nn.Module,
+                                          List[nn.Module]],
+                            depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
+        repeat = get_fixed_value(label)
+        return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))

     def __init__(self,
-                 blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]],
-                 depth: Union[int, Tuple[int, int]], label: Optional[str] = None):
+                 blocks: Union[Callable[[int], nn.Module],
+                               List[Callable[[int], nn.Module]],
+                               nn.Module,
+                               List[nn.Module]],
+                 depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
         super().__init__()
         self._label = generate_new_label(label)
         self.min_depth = depth if isinstance(depth, int) else depth[0]
@@ -69,7 +74,7 @@ class Repeat(nn.Module):
         assert repeat <= len(blocks), f'Not enough blocks to be used. {repeat} expected, only found {len(blocks)}.'
         blocks = blocks[:repeat]
         if not isinstance(blocks[0], nn.Module):
-            blocks = [b() for b in blocks]
+            blocks = [b(i) for i, b in enumerate(blocks)]
         return blocks
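
As the updated docstring notes, a callable passed to Repeat now receives the repetition index, so each repeated block can be instantiated independently. A usage sketch mirroring the new unit test further below (the import paths follow the usual Retiarii conventions and are assumptions here, not part of this diff):

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper

@model_wrapper
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        # The factory receives the repetition index, so every repeated block
        # carries its own independently mutated LayerChoice.
        self.block = nn.Repeat(
            lambda index: nn.LayerChoice([nn.Linear(4, 4), nn.Identity()]),
            (2, 3), label='rep')

    def forward(self, x):
        return self.block(x)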
nni/retiarii/nn/pytorch/nasbench101.py

@@ -6,11 +6,10 @@ import numpy as np
 import torch
 import torch.nn as nn

-from nni.retiarii.mutator import InvalidMutation, Mutator
-from nni.retiarii.graph import Model
 from .api import InputChoice, ValueChoice, LayerChoice
-from .utils import generate_new_label, get_fixed_dict
-from ...utils import NoContextError
+from .utils import Mutable, generate_new_label, get_fixed_dict
+from ...mutator import InvalidMutation, Mutator
+from ...graph import Model

 _logger = logging.getLogger(__name__)
@@ -218,7 +217,7 @@ class _NasBench101CellFixed(nn.Module):
         return outputs


-class NasBench101Cell(nn.Module):
+class NasBench101Cell(Mutable):
     """
     Cell structure that is proposed in NAS-Bench-101 [nasbench101]_ .
@@ -289,23 +288,21 @@ class NasBench101Cell(nn.Module):
             return OrderedDict([(str(i), t) for i, t in enumerate(x)])
         return OrderedDict(x)

-    def __new__(cls, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
-                in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
-                max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
+    @classmethod
+    def create_fixed_module(cls, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
+                            in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
+                            max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
         def make_list(x): return x if isinstance(x, list) else [x]

-        try:
-            label, selected = get_fixed_dict(label)
-            op_candidates = cls._make_dict(op_candidates)
-            num_nodes = selected[f'{label}/num_nodes']
-            adjacency_list = [make_list(selected[f'{label}/input{i}']) for i in range(1, num_nodes)]
-            if sum([len(e) for e in adjacency_list]) > max_num_edges:
-                raise InvalidMutation(f'Expected {max_num_edges} edges, found: {adjacency_list}')
-            return _NasBench101CellFixed(
-                [op_candidates[selected[f'{label}/op{i}']] for i in range(1, num_nodes - 1)],
-                adjacency_list, in_features, out_features, num_nodes, projection)
-        except NoContextError:
-            return super().__new__(cls)
+        label, selected = get_fixed_dict(label)
+        op_candidates = cls._make_dict(op_candidates)
+        num_nodes = selected[f'{label}/num_nodes']
+        adjacency_list = [make_list(selected[f'{label}/input{i}']) for i in range(1, num_nodes)]
+        if sum([len(e) for e in adjacency_list]) > max_num_edges:
+            raise InvalidMutation(f'Expected {max_num_edges} edges, found: {adjacency_list}')
+        return _NasBench101CellFixed(
+            [op_candidates[selected[f'{label}/op{i}']] for i in range(1, num_nodes - 1)],
+            adjacency_list, in_features, out_features, num_nodes, projection)

     def __init__(self, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
                  in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
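
For reference, create_fixed_module above consumes a fixed dict whose keys are namespaced by the cell label: {label}/num_nodes, {label}/input{i} and {label}/op{i}. A standalone sketch with a hypothetical selection (the concrete values and op names are made up) showing how the adjacency list and op list are derived:

# Hypothetical fixed dict for a cell labeled 'cell', shaped like the keys
# NasBench101Cell.create_fixed_module reads in the diff above.
selected = {
    'cell/num_nodes': 4,
    'cell/input1': 0,          # node 1 takes input from node 0
    'cell/input2': [0, 1],     # node 2 takes inputs from nodes 0 and 1
    'cell/input3': [1, 2],     # node 3 (output) takes inputs from nodes 1 and 2
    'cell/op1': 'conv3x3',
    'cell/op2': 'maxpool',
}

def make_list(x):
    return x if isinstance(x, list) else [x]

num_nodes = selected['cell/num_nodes']
adjacency_list = [make_list(selected[f'cell/input{i}']) for i in range(1, num_nodes)]
ops = [selected[f'cell/op{i}'] for i in range(1, num_nodes - 1)]
print(adjacency_list)  # [[0], [0, 1], [1, 2]]
print(ops)             # ['conv3x3', 'maxpool']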
nni/retiarii/nn/pytorch/utils.py

-from typing import Any, Optional, Tuple
+from typing import Any, Optional, Tuple, Union

-from nni.retiarii.utils import ModelNamespace, get_current_context
+import torch.nn as nn
+
+from nni.retiarii.utils import NoContextError, ModelNamespace, get_current_context
+
+
+class Mutable(nn.Module):
+    """
+    This is just an implementation trick for now.
+
+    In future, this could be the base class for all PyTorch mutables including layer choice, input choice, etc.
+    This is not considered as an interface, but rather as a base class consisting of commonly used class/instance methods.
+    For API developers, it's not recommended to use ``isinstance(module, Mutable)`` to check for mutable modules either,
+    before the design is finalized.
+    """
+
+    def __new__(cls, *args, **kwargs):
+        if not args and not kwargs:
+            # this can be the case of copy/deepcopy
+            # attributes are assigned afterwards in __dict__
+            return super().__new__(cls)
+
+        try:
+            return cls.create_fixed_module(*args, **kwargs)
+        except NoContextError:
+            return super().__new__(cls)
+
+    @classmethod
+    def create_fixed_module(cls, *args, **kwargs) -> Union[nn.Module, Any]:
+        """
+        Try to create a fixed module from fixed dict.
+        If the code is running in a trial, this method would succeed, and a concrete module instead of a mutable will be created.
+        Raises no context error if the creation failed.
+        """
+        raise NotImplementedError
+

 def generate_new_label(label: Optional[str]):
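
The two branches of Mutable.__new__ are what make both deepcopy and fixed-architecture instantiation work: no arguments means copy/deepcopy is reconstructing the object, while a normal constructor call first tries create_fixed_module and falls back to a regular mutable. A simplified, framework-free sketch of the same trick (MutableSketch, ChoiceSketch and the local NoContextError stand-in are illustrative, not NNI classes):

import copy

class NoContextError(Exception):
    # Stand-in for nni.retiarii.utils.NoContextError in this standalone sketch.
    pass

class MutableSketch:
    """Simplified version of the Mutable.__new__ trick shown in the diff above."""

    def __new__(cls, *args, **kwargs):
        if not args and not kwargs:
            # copy/deepcopy re-creates instances via cls.__new__(cls) with no
            # arguments and restores attributes through __dict__ afterwards.
            return super().__new__(cls)
        try:
            return cls.create_fixed_module(*args, **kwargs)
        except NoContextError:
            return super().__new__(cls)

    @classmethod
    def create_fixed_module(cls, *args, **kwargs):
        # Outside a fixed-architecture context, behave like a regular mutable.
        raise NoContextError

class ChoiceSketch(MutableSketch):
    def __init__(self, candidates, *, label=None):
        self.candidates = candidates
        self.label = label

choice = ChoiceSketch([1, 2, 3], label='lc')
clone = copy.deepcopy(choice)          # works now: __new__ sees no arguments
print(clone.candidates, clone.label)   # [1, 2, 3] lc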
test/ut/retiarii/test_highlevel_apis.py

@@ -483,6 +483,54 @@ class GraphIR(unittest.TestCase):
         self.assertTrue((self._get_converted_pytorch_model(model2)(torch.zeros(1, 16)) == 4).all())
         self.assertTrue((self._get_converted_pytorch_model(model3)(torch.zeros(1, 16)) == 5).all())

+    def test_repeat_complex(self):
+        class AddOne(nn.Module):
+            def forward(self, x):
+                return x + 1
+
+        @model_wrapper
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.block = nn.Repeat(nn.LayerChoice([AddOne(), nn.Identity()], label='lc'), (3, 5), label='rep')
+
+            def forward(self, x):
+                return self.block(x)
+
+        model, mutators = self._get_model_with_mutators(Net())
+        self.assertEqual(len(mutators), 2)
+        self.assertEqual(set([mutator.label for mutator in mutators]), {'lc', 'rep'})
+
+        sampler = RandomSampler()
+        for _ in range(10):
+            new_model = model
+            for mutator in mutators:
+                new_model = mutator.bind_sampler(sampler).apply(new_model)
+            result = self._get_converted_pytorch_model(new_model)(torch.zeros(1, 1)).item()
+            self.assertIn(result, [0., 3., 4., 5.])
+
+        # independent layer choice
+        @model_wrapper
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.block = nn.Repeat(lambda index: nn.LayerChoice([AddOne(), nn.Identity()]), (2, 3), label='rep')
+
+            def forward(self, x):
+                return self.block(x)
+
+        model, mutators = self._get_model_with_mutators(Net())
+        self.assertEqual(len(mutators), 4)
+
+        result = []
+        for _ in range(20):
+            new_model = model
+            for mutator in mutators:
+                new_model = mutator.bind_sampler(sampler).apply(new_model)
+            result.append(self._get_converted_pytorch_model(new_model)(torch.zeros(1, 1)).item())
+        self.assertIn(1., result)
+
     def test_cell(self):
         @self.get_serializer()
         class Net(nn.Module):