Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
319ff036
Unverified
Commit
319ff036
authored
Apr 24, 2020
by
Yuge Zhang
Committed by
GitHub
Apr 24, 2020
Browse files
Allow kwargs in layer choice (#2351)
parent
80242f2b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
19 additions
and
16 deletions
+19
-16
src/sdk/pynni/nni/nas/pytorch/base_mutator.py
src/sdk/pynni/nni/nas/pytorch/base_mutator.py
+4
-2
src/sdk/pynni/nni/nas/pytorch/mutables.py
src/sdk/pynni/nni/nas/pytorch/mutables.py
+2
-5
src/sdk/pynni/nni/nas/pytorch/mutator.py
src/sdk/pynni/nni/nas/pytorch/mutator.py
+8
-6
src/sdk/pynni/nni/nas/pytorch/proxylessnas/mutator.py
src/sdk/pynni/nni/nas/pytorch/proxylessnas/mutator.py
+5
-3
No files found.
src/sdk/pynni/nni/nas/pytorch/base_mutator.py
View file @
319ff036
...
...
@@ -104,7 +104,7 @@ class BaseMutator(nn.Module):
"""
pass
def
on_forward_layer_choice
(
self
,
mutable
,
*
input
s
):
def
on_forward_layer_choice
(
self
,
mutable
,
*
args
,
**
kwarg
s
):
"""
Callbacks of forward in LayerChoice.
...
...
@@ -112,8 +112,10 @@ class BaseMutator(nn.Module):
----------
mutable : LayerChoice
Module whose forward is called.
input
s : list of torch.Tensor
arg
s : list of torch.Tensor
The arguments of its forward function.
kwargs : dict
The keyword arguments of its forward function.
Returns
-------
...
...
src/sdk/pynni/nni/nas/pytorch/mutables.py
View file @
319ff036
...
...
@@ -58,9 +58,6 @@ class Mutable(nn.Module):
"Or did you apply multiple fixed architectures?"
)
self
.
__dict__
[
"mutator"
]
=
mutator
def
forward
(
self
,
*
inputs
):
raise
NotImplementedError
@
property
def
key
(
self
):
"""
...
...
@@ -155,14 +152,14 @@ class LayerChoice(Mutable):
self
.
reduction
=
reduction
self
.
return_mask
=
return_mask
def
forward
(
self
,
*
input
s
):
def
forward
(
self
,
*
args
,
**
kwarg
s
):
"""
Returns
-------
tuple of tensors
Output and selection mask. If ``return_mask`` is ``False``, only output is returned.
"""
out
,
mask
=
self
.
mutator
.
on_forward_layer_choice
(
self
,
*
input
s
)
out
,
mask
=
self
.
mutator
.
on_forward_layer_choice
(
self
,
*
args
,
**
kwarg
s
)
if
self
.
return_mask
:
return
out
,
mask
return
out
...
...
src/sdk/pynni/nni/nas/pytorch/mutator.py
View file @
319ff036
...
...
@@ -128,7 +128,7 @@ class Mutator(BaseMutator):
result
[
"mutable"
][
mutable
.
key
].
append
(
path
)
return
result
def
on_forward_layer_choice
(
self
,
mutable
,
*
input
s
):
def
on_forward_layer_choice
(
self
,
mutable
,
*
args
,
**
kwarg
s
):
"""
On default, this method retrieves the decision obtained previously, and select certain operations.
Only operations with non-zero weight will be executed. The results will be added to a list.
...
...
@@ -138,7 +138,9 @@ class Mutator(BaseMutator):
----------
mutable : LayerChoice
Layer choice module.
inputs : list of torch.Tensor
args : list of torch.Tensor
Inputs
kwargs : dict
Inputs
Returns
...
...
@@ -148,16 +150,16 @@ class Mutator(BaseMutator):
"""
if
self
.
_connect_all
:
return
self
.
_all_connect_tensor_reduction
(
mutable
.
reduction
,
[
op
(
*
input
s
)
for
op
in
mutable
.
choices
]),
\
[
op
(
*
args
,
**
kwarg
s
)
for
op
in
mutable
.
choices
]),
\
torch
.
ones
(
mutable
.
length
)
def
_map_fn
(
op
,
*
input
s
):
return
op
(
*
input
s
)
def
_map_fn
(
op
,
args
,
kwarg
s
):
return
op
(
*
args
,
**
kwarg
s
)
mask
=
self
.
_get_decision
(
mutable
)
assert
len
(
mask
)
==
len
(
mutable
.
choices
),
\
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
len
(
mutable
.
choices
))
out
=
self
.
_select_with_mask
(
_map_fn
,
[(
choice
,
*
input
s
)
for
choice
in
mutable
.
choices
],
mask
)
out
=
self
.
_select_with_mask
(
_map_fn
,
[(
choice
,
args
,
kwarg
s
)
for
choice
in
mutable
.
choices
],
mask
)
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
def
on_forward_input_choice
(
self
,
mutable
,
tensor_list
):
...
...
src/sdk/pynni/nni/nas/pytorch/proxylessnas/mutator.py
View file @
319ff036
...
...
@@ -317,7 +317,7 @@ class ProxylessNasMutator(BaseMutator):
self
.
mutable_list
.
append
(
mutable
)
mutable
.
registered_module
=
MixedOp
(
mutable
)
def
on_forward_layer_choice
(
self
,
mutable
,
*
input
s
):
def
on_forward_layer_choice
(
self
,
mutable
,
*
args
,
**
kwarg
s
):
"""
Callback of layer choice forward. This function defines the forward
logic of the input mutable. So mutable is only interface, its real
...
...
@@ -327,7 +327,9 @@ class ProxylessNasMutator(BaseMutator):
----------
mutable: LayerChoice
forward logic of this input mutable
inputs: list of torch.Tensor
args: list of torch.Tensor
inputs of this mutable
kwargs: dict
inputs of this mutable
Returns
...
...
@@ -339,7 +341,7 @@ class ProxylessNasMutator(BaseMutator):
"""
# FIXME: return mask, to be consistent with other algorithms
idx
=
mutable
.
registered_module
.
active_op_index
return
mutable
.
registered_module
(
mutable
,
*
input
s
),
idx
return
mutable
.
registered_module
(
mutable
,
*
args
,
**
kwarg
s
),
idx
def
reset_binary_gates
(
self
):
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment