Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
319ff036
Unverified
Commit
319ff036
authored
Apr 24, 2020
by
Yuge Zhang
Committed by
GitHub
Apr 24, 2020
Browse files
Allow kwargs in layer choice (#2351)
parent
80242f2b
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
19 additions
and
16 deletions
+19
-16
src/sdk/pynni/nni/nas/pytorch/base_mutator.py
src/sdk/pynni/nni/nas/pytorch/base_mutator.py
+4
-2
src/sdk/pynni/nni/nas/pytorch/mutables.py
src/sdk/pynni/nni/nas/pytorch/mutables.py
+2
-5
src/sdk/pynni/nni/nas/pytorch/mutator.py
src/sdk/pynni/nni/nas/pytorch/mutator.py
+8
-6
src/sdk/pynni/nni/nas/pytorch/proxylessnas/mutator.py
src/sdk/pynni/nni/nas/pytorch/proxylessnas/mutator.py
+5
-3
No files found.
src/sdk/pynni/nni/nas/pytorch/base_mutator.py
View file @
319ff036
...
@@ -104,7 +104,7 @@ class BaseMutator(nn.Module):
...
@@ -104,7 +104,7 @@ class BaseMutator(nn.Module):
"""
"""
pass
pass
def
on_forward_layer_choice
(
self
,
mutable
,
*
input
s
):
def
on_forward_layer_choice
(
self
,
mutable
,
*
args
,
**
kwarg
s
):
"""
"""
Callbacks of forward in LayerChoice.
Callbacks of forward in LayerChoice.
...
@@ -112,8 +112,10 @@ class BaseMutator(nn.Module):
...
@@ -112,8 +112,10 @@ class BaseMutator(nn.Module):
----------
----------
mutable : LayerChoice
mutable : LayerChoice
Module whose forward is called.
Module whose forward is called.
input
s : list of torch.Tensor
arg
s : list of torch.Tensor
The arguments of its forward function.
The arguments of its forward function.
kwargs : dict
The keyword arguments of its forward function.
Returns
Returns
-------
-------
...
...
src/sdk/pynni/nni/nas/pytorch/mutables.py
View file @
319ff036
...
@@ -58,9 +58,6 @@ class Mutable(nn.Module):
...
@@ -58,9 +58,6 @@ class Mutable(nn.Module):
"Or did you apply multiple fixed architectures?"
)
"Or did you apply multiple fixed architectures?"
)
self
.
__dict__
[
"mutator"
]
=
mutator
self
.
__dict__
[
"mutator"
]
=
mutator
def
forward
(
self
,
*
inputs
):
raise
NotImplementedError
@
property
@
property
def
key
(
self
):
def
key
(
self
):
"""
"""
...
@@ -155,14 +152,14 @@ class LayerChoice(Mutable):
...
@@ -155,14 +152,14 @@ class LayerChoice(Mutable):
self
.
reduction
=
reduction
self
.
reduction
=
reduction
self
.
return_mask
=
return_mask
self
.
return_mask
=
return_mask
def
forward
(
self
,
*
input
s
):
def
forward
(
self
,
*
args
,
**
kwarg
s
):
"""
"""
Returns
Returns
-------
-------
tuple of tensors
tuple of tensors
Output and selection mask. If ``return_mask`` is ``False``, only output is returned.
Output and selection mask. If ``return_mask`` is ``False``, only output is returned.
"""
"""
out
,
mask
=
self
.
mutator
.
on_forward_layer_choice
(
self
,
*
input
s
)
out
,
mask
=
self
.
mutator
.
on_forward_layer_choice
(
self
,
*
args
,
**
kwarg
s
)
if
self
.
return_mask
:
if
self
.
return_mask
:
return
out
,
mask
return
out
,
mask
return
out
return
out
...
...
src/sdk/pynni/nni/nas/pytorch/mutator.py
View file @
319ff036
...
@@ -128,7 +128,7 @@ class Mutator(BaseMutator):
...
@@ -128,7 +128,7 @@ class Mutator(BaseMutator):
result
[
"mutable"
][
mutable
.
key
].
append
(
path
)
result
[
"mutable"
][
mutable
.
key
].
append
(
path
)
return
result
return
result
def
on_forward_layer_choice
(
self
,
mutable
,
*
input
s
):
def
on_forward_layer_choice
(
self
,
mutable
,
*
args
,
**
kwarg
s
):
"""
"""
On default, this method retrieves the decision obtained previously, and select certain operations.
On default, this method retrieves the decision obtained previously, and select certain operations.
Only operations with non-zero weight will be executed. The results will be added to a list.
Only operations with non-zero weight will be executed. The results will be added to a list.
...
@@ -138,7 +138,9 @@ class Mutator(BaseMutator):
...
@@ -138,7 +138,9 @@ class Mutator(BaseMutator):
----------
----------
mutable : LayerChoice
mutable : LayerChoice
Layer choice module.
Layer choice module.
inputs : list of torch.Tensor
args : list of torch.Tensor
Inputs
kwargs : dict
Inputs
Inputs
Returns
Returns
...
@@ -148,16 +150,16 @@ class Mutator(BaseMutator):
...
@@ -148,16 +150,16 @@ class Mutator(BaseMutator):
"""
"""
if
self
.
_connect_all
:
if
self
.
_connect_all
:
return
self
.
_all_connect_tensor_reduction
(
mutable
.
reduction
,
return
self
.
_all_connect_tensor_reduction
(
mutable
.
reduction
,
[
op
(
*
input
s
)
for
op
in
mutable
.
choices
]),
\
[
op
(
*
args
,
**
kwarg
s
)
for
op
in
mutable
.
choices
]),
\
torch
.
ones
(
mutable
.
length
)
torch
.
ones
(
mutable
.
length
)
def
_map_fn
(
op
,
*
input
s
):
def
_map_fn
(
op
,
args
,
kwarg
s
):
return
op
(
*
input
s
)
return
op
(
*
args
,
**
kwarg
s
)
mask
=
self
.
_get_decision
(
mutable
)
mask
=
self
.
_get_decision
(
mutable
)
assert
len
(
mask
)
==
len
(
mutable
.
choices
),
\
assert
len
(
mask
)
==
len
(
mutable
.
choices
),
\
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
len
(
mutable
.
choices
))
"Invalid mask, expected {} to be of length {}."
.
format
(
mask
,
len
(
mutable
.
choices
))
out
=
self
.
_select_with_mask
(
_map_fn
,
[(
choice
,
*
input
s
)
for
choice
in
mutable
.
choices
],
mask
)
out
=
self
.
_select_with_mask
(
_map_fn
,
[(
choice
,
args
,
kwarg
s
)
for
choice
in
mutable
.
choices
],
mask
)
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
return
self
.
_tensor_reduction
(
mutable
.
reduction
,
out
),
mask
def
on_forward_input_choice
(
self
,
mutable
,
tensor_list
):
def
on_forward_input_choice
(
self
,
mutable
,
tensor_list
):
...
...
src/sdk/pynni/nni/nas/pytorch/proxylessnas/mutator.py
View file @
319ff036
...
@@ -317,7 +317,7 @@ class ProxylessNasMutator(BaseMutator):
...
@@ -317,7 +317,7 @@ class ProxylessNasMutator(BaseMutator):
self
.
mutable_list
.
append
(
mutable
)
self
.
mutable_list
.
append
(
mutable
)
mutable
.
registered_module
=
MixedOp
(
mutable
)
mutable
.
registered_module
=
MixedOp
(
mutable
)
def
on_forward_layer_choice
(
self
,
mutable
,
*
input
s
):
def
on_forward_layer_choice
(
self
,
mutable
,
*
args
,
**
kwarg
s
):
"""
"""
Callback of layer choice forward. This function defines the forward
Callback of layer choice forward. This function defines the forward
logic of the input mutable. So mutable is only interface, its real
logic of the input mutable. So mutable is only interface, its real
...
@@ -327,7 +327,9 @@ class ProxylessNasMutator(BaseMutator):
...
@@ -327,7 +327,9 @@ class ProxylessNasMutator(BaseMutator):
----------
----------
mutable: LayerChoice
mutable: LayerChoice
forward logic of this input mutable
forward logic of this input mutable
inputs: list of torch.Tensor
args: list of torch.Tensor
inputs of this mutable
kwargs: dict
inputs of this mutable
inputs of this mutable
Returns
Returns
...
@@ -339,7 +341,7 @@ class ProxylessNasMutator(BaseMutator):
...
@@ -339,7 +341,7 @@ class ProxylessNasMutator(BaseMutator):
"""
"""
# FIXME: return mask, to be consistent with other algorithms
# FIXME: return mask, to be consistent with other algorithms
idx
=
mutable
.
registered_module
.
active_op_index
idx
=
mutable
.
registered_module
.
active_op_index
return
mutable
.
registered_module
(
mutable
,
*
input
s
),
idx
return
mutable
.
registered_module
(
mutable
,
*
args
,
**
kwarg
s
),
idx
def
reset_binary_gates
(
self
):
def
reset_binary_gates
(
self
):
"""
"""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment