OpenDAS / nni · Commit 9e2a069d
Commit 9e2a069d (unverified), authored Aug 29, 2022 by Maze, committed by GitHub on Aug 29, 2022. Parent: d68691d0.

One-shot sub state dict implementation (#5054)
Changes: showing 9 changed files with 436 additions and 124 deletions (+436 −124).
nni/nas/hub/pytorch/autoformer.py (+18 −7)
nni/nas/oneshot/pytorch/base_lightning.py (+7 −3)
nni/nas/oneshot/pytorch/sampling.py (+41 −2)
nni/nas/oneshot/pytorch/strategy.py (+26 −10)
nni/nas/oneshot/pytorch/supermodule/base.py (+70 −2)
nni/nas/oneshot/pytorch/supermodule/operation.py (+194 −97)
nni/nas/oneshot/pytorch/supermodule/sampling.py (+19 −2)
test/algo/nas/test_oneshot.py (+19 −0)
test/algo/nas/test_oneshot_supermodules.py (+42 −1)
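Taken together, the changes let a random one-shot supernet hand its weights to a single sampled subnet. A minimal sketch of the intended workflow, modelled on the new `test_one_shot_sub_state_dict` added in `test/algo/nas/test_oneshot.py` below; `MyModelSpace` is a placeholder for whatever NNI model space class you already have, not something added by this commit:

```python
import torch
from nni.nas import fixed_arch
from nni.nas.strategy import RandomOneShot

model_space = MyModelSpace()            # placeholder: any NNI model space with value/layer choices

strategy = RandomOneShot()
strategy.attach_model(model_space)      # new in this commit: attach without calling run()

arch = strategy.model.resample()        # sample one architecture (dict of choice label -> value)
with fixed_arch(arch):
    subnet = MyModelSpace()             # instantiate the fixed subnet for that architecture

# New in this commit: pull only the weights belonging to `arch` out of the supernet,
# with keys that match the fixed subnet.
subnet.load_state_dict(strategy.sub_state_dict(arch))
```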
nni/nas/hub/pytorch/autoformer.py (view file @ 9e2a069d)

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from typing import Optional, Tuple, cast, Any, Dict
+from typing import Optional, Tuple, cast, Any, Dict, Union

 import torch
 import torch.nn.functional as F

@@ -135,7 +135,7 @@ class TransformerEncoderLayer(nn.Module):
     The pytorch build-in nn.TransformerEncoderLayer() does not support customed attention.
     """
-    def __init__(self, embed_dim, num_heads, mlp_ratio=4.,
+    def __init__(self, embed_dim, num_heads, mlp_ratio: Union[int, float, nn.ValueChoice] = 4.,
                  qkv_bias=False, qk_scale=None, rpe=False,
                  drop_rate=0., attn_drop=0., proj_drop=0., drop_path=0.,
                  pre_norm=True, rpe_length=14, head_dim=64

@@ -235,13 +235,18 @@ class MixedClsToken(MixedOperation, ClsToken):
     def super_init_argument(self, name: str, value_choice: ValueChoiceX):
         return max(traverse_all_options(value_choice))

-    def forward_with_args(self, embed_dim,
-                          inputs: torch.Tensor) -> torch.Tensor:
+    def slice_param(self, embed_dim, **kwargs) -> Any:
         embed_dim_ = _W(embed_dim)
         cls_token = _S(self.cls_token)[..., :embed_dim_]
+        return {'cls_token': cls_token}
+
+    def forward_with_args(self, embed_dim,
+                          inputs: torch.Tensor) -> torch.Tensor:
+        cls_token = self.slice_param(embed_dim)['cls_token']
+        assert isinstance(cls_token, torch.Tensor)
         return torch.cat((cls_token.expand(inputs.shape[0], -1, -1), inputs), dim=1)

 @basic_unit
 class AbsPosEmbed(nn.Module):

@@ -271,11 +276,17 @@ class MixedAbsPosEmbed(MixedOperation, AbsPosEmbed):
     def super_init_argument(self, name: str, value_choice: ValueChoiceX):
         return max(traverse_all_options(value_choice))

-    def forward_with_args(self, embed_dim,
-                          inputs: torch.Tensor) -> torch.Tensor:
+    def slice_param(self, embed_dim, **kwargs) -> Any:
         embed_dim_ = _W(embed_dim)
         pos_embed = _S(self.pos_embed)[..., :embed_dim_]
+        return {'pos_embed': pos_embed}
+
+    def forward_with_args(self, embed_dim,
+                          inputs: torch.Tensor) -> torch.Tensor:
+        pos_embed = self.slice_param(embed_dim)['pos_embed']
+        assert isinstance(pos_embed, torch.Tensor)
         return inputs + pos_embed
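The change above follows one pattern used throughout the commit: the parameter slicing that used to live inside `forward_with_args` is factored into a `slice_param` method that returns a dict of already-sliced tensors keyed by parameter name, and the forward pass merely looks them up. The same dict can then be reused to build a subnet state dict. A self-contained toy version of the idea (everything here, including `ToyMixedEmbed`, is illustrative and not part of NNI):

```python
import torch
import torch.nn as nn


class ToyMixedEmbed(nn.Module):
    """Stores the largest embedding; serves smaller widths by slicing."""

    def __init__(self, max_dim: int = 64):
        super().__init__()
        self.pos_embed = nn.Parameter(torch.zeros(1, 16, max_dim))

    def slice_param(self, embed_dim: int) -> dict:
        # Same contract as in the diff: return already-sliced tensors keyed by parameter name.
        return {'pos_embed': self.pos_embed[..., :embed_dim]}

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        # Forward only consumes what slice_param produced, so state-dict export
        # can call slice_param with the same arguments and get identical weights.
        pos_embed = self.slice_param(inputs.shape[-1])['pos_embed']
        return inputs + pos_embed


toy = ToyMixedEmbed()
print(toy(torch.zeros(1, 16, 32)).shape)  # torch.Size([1, 16, 32])
```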
nni/nas/oneshot/pytorch/base_lightning.py (view file @ 9e2a069d)

@@ -77,7 +77,6 @@ def traverse_and_mutate_submodules(
     memo = {}
     module_list = []

     def apply(m):
-
         # Need to call list() here because the loop body might replace some children in-place.
         for name, child in list(m.named_children()):

@@ -280,16 +279,21 @@ class BaseOneShotLightningModule(pl.LightningModule):
             result.update(module.search_space_spec())
         return result

-    def resample(self) -> dict[str, Any]:
+    def resample(self, memo=None) -> dict[str, Any]:
         """Trigger the resample for each :attr:`nas_modules`.
         Sometimes (e.g., in differentiable cases), it does nothing.

+        Parameters
+        ----------
+        memo : dict[str, Any]
+            Used to ensure the consistency of samples with the same label.
+
         Returns
        -------
         dict
             Sampled architecture.
         """
-        result = {}
+        result = memo or {}
         for module in self.nas_modules:
             result.update(module.resample(memo=result))
         return result
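The only behavioural change here is the new `memo` argument: a caller can pre-seed `resample` with an architecture dict, and every choice whose label already appears in the memo keeps that value instead of being sampled again. That is exactly what `sub_state_dict` relies on to pin the supernet to a given architecture before exporting weights. A small self-contained illustration of the memo mechanism (the `ToyChoice` class is made up for this sketch):

```python
import random
from typing import Any, Dict, Optional


class ToyChoice:
    """Stands in for a one-shot NAS module with a labelled choice (illustrative only)."""

    def __init__(self, label: str, candidates):
        self.label = label
        self.candidates = candidates

    def resample(self, memo: Dict[str, Any]) -> Dict[str, Any]:
        # Reuse the memoised value when this label was already decided.
        if self.label not in memo:
            memo[self.label] = random.choice(self.candidates)
        return {self.label: memo[self.label]}


def resample(modules, memo: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    # Same shape as the new BaseOneShotLightningModule.resample: a caller-provided
    # memo pins choices (e.g. a previously exported arch); the rest is sampled fresh.
    result = memo or {}
    for module in modules:
        result.update(module.resample(memo=result))
    return result


mods = [ToyChoice('width', [32, 48, 64]), ToyChoice('width', [32, 48, 64]), ToyChoice('depth', [2, 3])]
print(resample(mods, memo={'width': 48}))  # 'width' is pinned to 48 for both modules sharing the label
```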
nni/nas/oneshot/pytorch/sampling.py (view file @ 9e2a069d)

@@ -5,7 +5,7 @@
 from __future__ import annotations

 import warnings
-from typing import Any, cast
+from typing import Any, cast, Dict

 import pytorch_lightning as pl
 import torch

@@ -19,7 +19,7 @@ from .supermodule.sampling import (
     PathSamplingCell, PathSamplingRepeat
 )
 from .enas import ReinforceController, ReinforceField
-
+from .supermodule.base import sub_state_dict

 class RandomSamplingLightningModule(BaseOneShotLightningModule):
     _random_note = """

@@ -92,6 +92,45 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
             )
         return super().export()

+    def _get_base_model(self):
+        assert isinstance(self.model.model, nn.Module)
+        base_model: nn.Module = self.model.model
+        return base_model
+
+    def state_dict(self, destination: Any = None, prefix: str = '', keep_vars: bool = False) -> Dict[str, Any]:
+        base_model = self._get_base_model()
+        state_dict = base_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
+        return state_dict
+
+    def load_state_dict(self, state_dict, strict: bool = True) -> None:
+        base_model = self._get_base_model()
+        base_model.load_state_dict(state_dict=state_dict, strict=strict)
+
+    def sub_state_dict(self, arch: dict[str, Any], destination: Any = None, prefix: str = '', keep_vars: bool = False) -> Dict[str, Any]:
+        """Given the architecture dict, return the state_dict which can be directly loaded by the fixed subnet.
+
+        Parameters
+        ----------
+        arch : dict[str, Any]
+            Subnet architecture dict.
+        destination : dict
+            If provided, the state of module will be updated into the dict and the same object is returned.
+            Otherwise, an ``OrderedDict`` will be created and returned.
+        prefix : str
+            A prefix added to parameter and buffer names to compose the keys in state_dict.
+        keep_vars : bool
+            By default the :class:`~torch.Tensor` s returned in the state dict are detached from autograd.
+            If it's set to ``True``, detaching will not be performed.
+
+        Returns
+        -------
+        dict
+            Subnet state dict.
+        """
+        self.resample(memo=arch)
+        base_model = self._get_base_model()
+        state_dict = sub_state_dict(base_model, destination, prefix, keep_vars)
+        return state_dict

 class EnasLightningModule(RandomSamplingLightningModule):
     _enas_note = """
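Besides `sub_state_dict`, the module now also overrides `state_dict` and `load_state_dict` so that checkpoints address the wrapped base model directly instead of the Lightning wrapper around it, which keeps the keys free of wrapper prefixes. A plain-PyTorch illustration of why that matters (hedged sketch; the real wrapper is a LightningModule and nests the model one level deeper):

```python
import torch.nn as nn


class PlainWrapper(nn.Module):
    def __init__(self, base: nn.Module):
        super().__init__()
        self.model = base


class DelegatingWrapper(PlainWrapper):
    # Same idea as the new RandomSamplingLightningModule.state_dict/load_state_dict:
    # read and write the weights of the wrapped base model directly.
    def state_dict(self, *args, **kwargs):
        return self.model.state_dict(*args, **kwargs)

    def load_state_dict(self, state_dict, strict: bool = True):
        return self.model.load_state_dict(state_dict, strict=strict)


base = nn.Linear(4, 2)
print(list(PlainWrapper(base).state_dict().keys()))       # ['model.weight', 'model.bias']
print(list(DelegatingWrapper(base).state_dict().keys()))  # ['weight', 'bias']
```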
nni/nas/oneshot/pytorch/strategy.py (view file @ 9e2a069d)

@@ -13,7 +13,7 @@ When adding/modifying a new strategy in this file, don't forget to link it in st
 from __future__ import annotations

 import warnings
-from typing import Any, Type
+from typing import Any, Type, Union

 import torch.nn as nn

@@ -48,31 +48,43 @@ class OneShotStrategy(BaseStrategy):
         """
         return train_dataloaders, val_dataloaders

+    def attach_model(self, base_model: Union[Model, nn.Module]):
+        _reason = 'The reason might be that you have used the wrong execution engine. Try to set engine to `oneshot` and try again.'
+
+        if isinstance(base_model, Model):
+            if not isinstance(base_model.python_object, nn.Module):
+                raise TypeError('Model is not a nn.Module. ' + _reason)
+            py_model: nn.Module = base_model.python_object
+            if not isinstance(base_model.evaluator, Lightning):
+                raise TypeError('Evaluator needs to be a lightning evaluator to make one-shot strategy work.')
+            evaluator_module: LightningModule = base_model.evaluator.module
+            evaluator_module.running_mode = 'oneshot'
+            evaluator_module.set_model(py_model)
+        else:
+            from nni.retiarii.evaluator.pytorch.lightning import ClassificationModule
+            evaluator_module = ClassificationModule()
+            evaluator_module.running_mode = 'oneshot'
+            evaluator_module.set_model(base_model)
+        self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
+
     def run(self, base_model: Model, applied_mutators):
         # one-shot strategy doesn't use ``applied_mutators``
         # but get the "mutators" on their own

         _reason = 'The reason might be that you have used the wrong execution engine. Try to set engine to `oneshot` and try again.'

-        if not isinstance(base_model.python_object, nn.Module):
-            raise TypeError('Model is not a nn.Module. ' + _reason)
-        py_model: nn.Module = base_model.python_object
         if applied_mutators:
             raise ValueError('Mutator is not empty. ' + _reason)

         if not isinstance(base_model.evaluator, Lightning):
             raise TypeError('Evaluator needs to be a lightning evaluator to make one-shot strategy work.')

-        evaluator_module: LightningModule = base_model.evaluator.module
-        evaluator_module.running_mode = 'oneshot'
-        evaluator_module.set_model(py_model)
-        self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
+        self.attach_model(base_model)
         evaluator: Lightning = base_model.evaluator
         if evaluator.train_dataloaders is None or evaluator.val_dataloaders is None:
             raise TypeError('Training and validation dataloader are both required to set in evaluator for one-shot strategy.')
         train_loader, val_loader = self.preprocess_dataloader(evaluator.train_dataloaders, evaluator.val_dataloaders)
+        assert isinstance(self.model, BaseOneShotLightningModule)
         evaluator.trainer.fit(self.model, train_loader, val_loader)

     def export_top_models(self, top_k: int = 1) -> list[Any]:

@@ -144,3 +156,7 @@ class RandomOneShot(OneShotStrategy):
     def __init__(self, **kwargs):
         super().__init__(RandomSamplingLightningModule, **kwargs)
+
+    def sub_state_dict(self, arch: dict[str, Any]):
+        assert isinstance(self.model, RandomSamplingLightningModule)
+        return self.model.sub_state_dict(arch)
\ No newline at end of file
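`attach_model` is the piece that makes weight export possible outside of `run()`: it accepts either the framework-level `Model` wrapper or a bare `nn.Module` and always ends up handing a plain module to the one-shot Lightning module. The dispatch shape, reduced to plain Python (names such as `GraphModel` and `attach` are invented for this sketch):

```python
from typing import Union

import torch.nn as nn


class GraphModel:
    """Stands in for nni's Model wrapper (placeholder for illustration)."""

    def __init__(self, python_object: nn.Module):
        self.python_object = python_object


def attach(base_model: Union[GraphModel, nn.Module]) -> nn.Module:
    # Accept either the framework-level wrapper or a raw nn.Module, and always
    # hand a plain nn.Module onwards -- the same shape as attach_model above.
    if isinstance(base_model, GraphModel):
        if not isinstance(base_model.python_object, nn.Module):
            raise TypeError('Model is not a nn.Module.')
        return base_model.python_object
    return base_model


print(type(attach(nn.Linear(2, 2))))              # <class 'torch.nn.modules.linear.Linear'>
print(type(attach(GraphModel(nn.Linear(2, 2)))))  # <class 'torch.nn.modules.linear.Linear'>
```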
nni/nas/oneshot/pytorch/supermodule/base.py (view file @ 9e2a069d)

@@ -3,13 +3,62 @@
 from __future__ import annotations

-from typing import Any
+from collections import OrderedDict
+import itertools
+from typing import Any, Dict

 import torch.nn as nn

 from nni.common.hpo_utils import ParameterSpec

-__all__ = ['BaseSuperNetModule']
+__all__ = ['BaseSuperNetModule', 'sub_state_dict']
+
+
+def sub_state_dict(module: Any, destination: Any = None, prefix: str = '', keep_vars: bool = False) -> Dict[str, Any]:
+    """Returns a dictionary containing a whole state of the BaseSuperNetModule.
+
+    Both parameters and persistent buffers (e.g. running averages) are
+    included. Keys are corresponding parameter and buffer names.
+    Parameters and buffers set to ``None`` are not included.
+
+    Parameters
+    ----------
+    arch : dict[str, Any]
+        Subnet architecture dict.
+    destination (dict, optional):
+        If provided, the state of module will be updated into the dict
+        and the same object is returned. Otherwise, an ``OrderedDict``
+        will be created and returned. Default: ``None``.
+    prefix (str, optional):
+        A prefix added to parameter and buffer names to compose the keys in state_dict.
+        Default: ``''``.
+    keep_vars (bool, optional):
+        By default the :class:`~torch.Tensor` s returned in the state dict are
+        detached from autograd. If it's set to ``True``, detaching will not be performed.
+        Default: ``False``.
+
+    Returns
+    -------
+    dict
+        Subnet state dictionary.
+    """
+    if destination is None:
+        destination = OrderedDict()
+        destination._metadata = OrderedDict()
+
+    local_metadata = dict(version=module._version)
+    if hasattr(destination, "_metadata"):
+        destination._metadata[prefix[:-1]] = local_metadata
+
+    if isinstance(module, BaseSuperNetModule):
+        module._save_to_sub_state_dict(destination, prefix, keep_vars)
+    else:
+        module._save_to_state_dict(destination, prefix, keep_vars)
+        for name, m in module._modules.items():
+            if m is not None:
+                sub_state_dict(m, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
+    return destination


 class BaseSuperNetModule(nn.Module):

@@ -104,3 +153,22 @@ class BaseSuperNetModule(nn.Module):
         See :class:`BaseOneShotLightningModule <nni.retiarii.oneshot.pytorch.base_lightning.BaseOneShotLightningModule>` for details.
         """
         raise NotImplementedError()
+
+    def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars):
+        """Save the params and buffers of the current module to state dict."""
+        for name, value in itertools.chain(self._parameters.items(), self._buffers.items()):  # direct children
+            if value is None or name in self._non_persistent_buffers_set:
+                # it won't appear in state dict
+                continue
+            destination[prefix + name] = value if keep_vars else value.detach()
+
+    def _save_module_to_state_dict(self, destination, prefix, keep_vars):
+        """Save the sub-module to state dict."""
+        for name, module in self._modules.items():
+            if module is not None:
+                sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
+
+    def _save_to_sub_state_dict(self, destination, prefix, keep_vars):
+        """Save to state dict."""
+        self._save_param_buff_to_state_dict(destination, prefix, keep_vars)
+        self._save_module_to_state_dict(destination, prefix, keep_vars)
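The helper above is essentially a re-implementation of `nn.Module.state_dict` traversal with one twist: when it reaches a supernet module it calls that module's `_save_to_sub_state_dict`, which is free to write sliced tensors; everything else falls back to PyTorch's standard `_save_to_state_dict` plus plain recursion. A stripped-down, self-contained sketch of that dispatch (class and function names here are invented; only the control flow mirrors the code above):

```python
from collections import OrderedDict

import torch
import torch.nn as nn


class SlicedLinear(nn.Linear):
    """A toy 'supernet' layer that knows how to export only its active slice."""

    def __init__(self, max_in, max_out, active_out):
        super().__init__(max_in, max_out)
        self.active_out = active_out

    def save_sliced(self, destination, prefix):
        # Write already-sliced tensors under the standard key names.
        destination[prefix + 'weight'] = self.weight.detach()[:self.active_out]
        destination[prefix + 'bias'] = self.bias.detach()[:self.active_out]


def collect(module: nn.Module, destination=None, prefix=''):
    if destination is None:
        destination = OrderedDict()
    if isinstance(module, SlicedLinear):
        module.save_sliced(destination, prefix)          # supernet modules export a sliced view
    else:
        module._save_to_state_dict(destination, prefix, keep_vars=False)  # regular PyTorch path
        for name, child in module._modules.items():
            if child is not None:
                collect(child, destination, prefix + name + '.')
    return destination


net = nn.Sequential(nn.Conv2d(3, 8, 3), SlicedLinear(8, 16, active_out=4))
sliced = collect(net)
print(sliced['1.weight'].shape)  # torch.Size([4, 8]) -- only the active rows are exported
```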
nni/nas/oneshot/pytorch/supermodule/operation.py (view file @ 9e2a069d)

@@ -23,7 +23,7 @@ from nni.common.hpo_utils import ParameterSpec
 from nni.common.serializer import is_traceable
 from nni.nas.nn.pytorch.choice import ValueChoiceX

-from .base import BaseSuperNetModule
+from .base import BaseSuperNetModule, sub_state_dict
 from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, evaluate_constant
 from ._operation_utils import Slicable as _S, MaybeWeighted as _W, int_or_int_dict, scalar_or_scalar_dict

@@ -232,6 +232,22 @@ class MixedOperation(BaseSuperNetModule):
             if param.default is not param.empty and param.name not in self.init_arguments:
                 self.init_arguments[param.name] = param.default

+    def slice_param(self, **kwargs):
+        """Slice the params and buffers for subnet forward and state dict.
+
+        When there is a `mapping=True` in kwargs, the return result will be wrapped in dict.
+        """
+        raise NotImplementedError()
+
+    def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars):
+        kwargs = {name: self.forward_argument(name) for name in self.argument_list}
+        params_mapping: dict[str, Any] = self.slice_param(**kwargs)
+
+        for name, value in itertools.chain(self._parameters.items(), self._buffers.items()):  # direct children
+            if value is None or name in self._non_persistent_buffers_set:
+                # it won't appear in state dict
+                continue
+            value = params_mapping.get(name, value)
+            destination[prefix + name] = value if keep_vars else value.detach()

 class MixedLinear(MixedOperation, nn.Linear):
     """Mixed linear operation.

@@ -250,20 +266,23 @@ class MixedLinear(MixedOperation, nn.Linear):
     def super_init_argument(self, name: str, value_choice: ValueChoiceX):
         return max(traverse_all_options(value_choice))

-    def forward_with_args(self,
-                          in_features: int_or_int_dict,
-                          out_features: int_or_int_dict,
-                          inputs: torch.Tensor) -> torch.Tensor:
-
+    def slice_param(self,
+                    in_features: int_or_int_dict,
+                    out_features: int_or_int_dict,
+                    **kwargs) -> Any:
         in_features_ = _W(in_features)
         out_features_ = _W(out_features)

         weight = _S(self.weight)[:out_features_]
         weight = _S(weight)[:, :in_features_]
-        if self.bias is None:
-            bias = self.bias
-        else:
-            bias = _S(self.bias)[:out_features_]
+        bias = self.bias if self.bias is None else _S(self.bias)[:out_features_]
+
+        return {'weight': weight, 'bias': bias}
+
+    def forward_with_args(self,
+                          in_features: int_or_int_dict,
+                          out_features: int_or_int_dict,
+                          inputs: torch.Tensor) -> torch.Tensor:
+        params_mapping = self.slice_param(in_features, out_features)
+        weight, bias = [params_mapping.get(name) for name in ['weight', 'bias']]

         return F.linear(inputs, weight, bias)

@@ -347,19 +366,13 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
         else:
             return max(traverse_all_options(value_choice))

-    def forward_with_args(self,
-                          in_channels: int_or_int_dict,
-                          out_channels: int_or_int_dict,
-                          kernel_size: scalar_or_scalar_dict[_int_or_tuple],
-                          stride: _int_or_tuple,
-                          padding: scalar_or_scalar_dict[_int_or_tuple],
-                          dilation: int,
-                          groups: int_or_int_dict,
-                          inputs: torch.Tensor) -> torch.Tensor:
-
-        if any(isinstance(arg, dict) for arg in [stride, dilation]):
-            raise ValueError(_diff_not_compatible_error.format('stride, dilation', 'Conv2d'))
-
+    def slice_param(self,
+                    in_channels: int_or_int_dict,
+                    out_channels: int_or_int_dict,
+                    kernel_size: scalar_or_scalar_dict[_int_or_tuple],
+                    groups: int_or_int_dict,
+                    **kwargs
+                    ) -> Any:
         in_channels_ = _W(in_channels)
         out_channels_ = _W(out_channels)

@@ -369,6 +382,8 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
         if not isinstance(groups, dict):
             weight = _S(weight)[:, :in_channels_ // groups]
+            # palce holder
+            in_channels_per_group = None
         else:
             assert 'groups' in self.mutable_arguments
             err_message = 'For differentiable one-shot strategy, when groups is a ValueChoice, ' \

@@ -383,15 +398,51 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
                 raise ValueError(err_message)
             if in_channels_per_group != int(in_channels_per_group):
                 raise ValueError(f'Input channels per group is found to be a non-integer: {in_channels_per_group}')
+            # Compute sliced weights and groups (as an integer)
+            weight = _S(weight)[:, :int(in_channels_per_group)]
+
+        kernel_a, kernel_b = self._to_tuple(kernel_size)
+        kernel_a_, kernel_b_ = _W(kernel_a), _W(kernel_b)
+        max_kernel_a, max_kernel_b = self.kernel_size  # self.kernel_size must be a tuple
+        kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a_) // 2, (max_kernel_b - kernel_b_) // 2
+        weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a_, kernel_b_top:kernel_b_top + kernel_b_]
+
+        bias = _S(self.bias)[:out_channels_] if self.bias is not None else None
+
+        return {'weight': weight, 'bias': bias, 'in_channels_per_group': in_channels_per_group}
+
+    def forward_with_args(self,
+                          in_channels: int_or_int_dict,
+                          out_channels: int_or_int_dict,
+                          kernel_size: scalar_or_scalar_dict[_int_or_tuple],
+                          stride: _int_or_tuple,
+                          padding: scalar_or_scalar_dict[_int_or_tuple],
+                          dilation: int,
+                          groups: int_or_int_dict,
+                          inputs: torch.Tensor) -> torch.Tensor:
+
+        if any(isinstance(arg, dict) for arg in [stride, dilation]):
+            raise ValueError(_diff_not_compatible_error.format('stride, dilation', 'Conv2d'))
+
+        params_mapping = self.slice_param(in_channels, out_channels, kernel_size, groups)
+        weight, bias, in_channels_per_group = [
+            params_mapping.get(name)
+            for name in ['weight', 'bias', 'in_channels_per_group']
+        ]
+
+        if isinstance(groups, dict):
+            if not isinstance(in_channels_per_group, (int, float)):
+                raise ValueError(f'Input channels per group is found to be a non-numberic: {in_channels_per_group}')
             if inputs.size(1) % in_channels_per_group != 0:
                 raise RuntimeError(
                     f'Input channels must be divisible by in_channels_per_group, but the input shape is {inputs.size()}, '
                     f'while in_channels_per_group = {in_channels_per_group}'
                 )
-
-            # Compute sliced weights and groups (as an integer)
-            weight = _S(weight)[:, :int(in_channels_per_group)]
-            groups = inputs.size(1) // int(in_channels_per_group)
+            else:
+                groups = inputs.size(1) // int(in_channels_per_group)

         # slice center
         if isinstance(kernel_size, dict):

@@ -400,14 +451,6 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
             raise ValueError(f'Use "{self.padding}" in padding is not supported.')
         padding = self.padding  # max padding, must be a tuple

-        kernel_a, kernel_b = self._to_tuple(kernel_size)
-        kernel_a_, kernel_b_ = _W(kernel_a), _W(kernel_b)
-        max_kernel_a, max_kernel_b = self.kernel_size  # self.kernel_size must be a tuple
-        kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a_) // 2, (max_kernel_b - kernel_b_) // 2
-        weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a_, kernel_b_top:kernel_b_top + kernel_b_]
-
-        bias = _S(self.bias)[:out_channels_] if self.bias is not None else None
-
         # The rest parameters only need to be converted to tuple
         stride_ = self._to_tuple(stride)
         dilation_ = self._to_tuple(dilation)

@@ -441,6 +484,21 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
     def super_init_argument(self, name: str, value_choice: ValueChoiceX):
         return max(traverse_all_options(value_choice))

+    def slice_param(self, num_features: int_or_int_dict, **kwargs) -> Any:
+        if isinstance(num_features, dict):
+            num_features = self.num_features
+        weight, bias = self.weight, self.bias
+        running_mean, running_var = self.running_mean, self.running_var
+
+        if num_features < self.num_features:
+            weight = weight[:num_features]
+            bias = bias[:num_features]
+            running_mean = None if running_mean is None else running_mean[:num_features]
+            running_var = None if running_var is None else running_var[:num_features]
+
+        return {'weight': weight, 'bias': bias, 'running_mean': running_mean, 'running_var': running_var}
+
     def forward_with_args(self,
                           num_features: int_or_int_dict,
                           eps: float,

@@ -450,19 +508,11 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
         if any(isinstance(arg, dict) for arg in [eps, momentum]):
             raise ValueError(_diff_not_compatible_error.format('eps and momentum', 'BatchNorm2d'))

-        if isinstance(num_features, dict):
-            num_features = self.num_features
-
-        weight, bias = self.weight, self.bias
-        running_mean, running_var = self.running_mean, self.running_var
-
-        if num_features < self.num_features:
-            weight = weight[:num_features]
-            bias = bias[:num_features]
-            if running_mean is not None:
-                running_mean = running_mean[:num_features]
-            if running_var is not None:
-                running_var = running_var[:num_features]
+        params_mapping = self.slice_param(num_features)
+        weight, bias, running_mean, running_var = [params_mapping.get(name)
+            for name in ['weight', 'bias', 'running_mean', 'running_var']
+        ]

         if self.training:
             bn_training = True

@@ -481,6 +531,7 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
             eps,
         )
+

 class MixedLayerNorm(MixedOperation, nn.LayerNorm):
     """
     Mixed LayerNorm operation.

@@ -517,14 +568,7 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm):
         else:
             return max(all_sizes)

-    def forward_with_args(self,
-                          normalized_shape,
-                          eps: float,
-                          inputs: torch.Tensor) -> torch.Tensor:
-
-        if any(isinstance(arg, dict) for arg in [eps]):
-            raise ValueError(_diff_not_compatible_error.format('eps', 'LayerNorm'))
-
+    def slice_param(self,
+                    normalized_shape, **kwargs) -> Any:
         if isinstance(normalized_shape, dict):
             normalized_shape = self.normalized_shape

@@ -541,6 +585,22 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm):
         weight = self.weight[indices] if self.weight is not None else None
         bias = self.bias[indices] if self.bias is not None else None

+        return {'weight': weight, 'bias': bias, 'normalized_shape': normalized_shape}
+
+    def forward_with_args(self,
+                          normalized_shape,
+                          eps: float,
+                          inputs: torch.Tensor) -> torch.Tensor:
+
+        if any(isinstance(arg, dict) for arg in [eps]):
+            raise ValueError(_diff_not_compatible_error.format('eps', 'LayerNorm'))
+
+        params_mapping = self.slice_param(normalized_shape)
+        weight, bias, normalized_shape = [params_mapping.get(name)
+            for name in ['weight', 'bias', 'normalized_shape']
+        ]
+
         return F.layer_norm(
             inputs,
             normalized_shape,

@@ -622,19 +682,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
             slice(self.embed_dim * 2, self.embed_dim * 2 + embed_dim)
         ]

-    def forward_with_args(
-        self, embed_dim: int_or_int_dict, num_heads: int,
-        kdim: int_or_int_dict | None, vdim: int_or_int_dict | None,
-        dropout: float,
-        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
-        key_padding_mask: torch.Tensor | None = None,
-        need_weights: bool = True, attn_mask: torch.Tensor | None = None
-    ) -> tuple[torch.Tensor, torch.Tensor | None]:
-
-        if any(isinstance(arg, dict) for arg in [num_heads, dropout]):
-            raise ValueError(_diff_not_compatible_error.format('num_heads and dropout', 'MultiHeadAttention'))
-
+    def slice_param(self, embed_dim, kdim, vdim, **kwargs):
         # by default, kdim, vdim can be none
         if kdim is None:
             kdim = embed_dim

@@ -643,15 +691,6 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
         qkv_same_embed_dim = kdim == embed_dim and vdim == embed_dim

-        if getattr(self, 'batch_first', False):
-            # for backward compatibility: v1.7 doesn't have batch_first
-            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
-
-        if isinstance(embed_dim, dict):
-            used_embed_dim = self.embed_dim
-        else:
-            used_embed_dim = embed_dim
-
         embed_dim_ = _W(embed_dim)

         # in projection weights & biases has q, k, v weights concatenated together

@@ -673,27 +712,84 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
             k_proj = _S(k_proj)[:, :_W(kdim)]
             v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim_]
             v_proj = _S(v_proj)[:, :_W(vdim)]
-
-            attn_output, attn_output_weights = F.multi_head_attention_forward(
-                query, key, value, used_embed_dim, num_heads,
-                cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
-                bias_k, bias_v, self.add_zero_attn,
-                dropout, out_proj_weight, cast(Tensor, out_proj_bias),
-                training=self.training,
-                key_padding_mask=key_padding_mask, need_weights=need_weights,
-                attn_mask=attn_mask, use_separate_proj_weight=True,
-                q_proj_weight=q_proj, k_proj_weight=k_proj, v_proj_weight=v_proj)
         else:
-            # Cast tensor here because of a bug in pytorch stub
-            attn_output, attn_output_weights = F.multi_head_attention_forward(
-                query, key, value, used_embed_dim, num_heads,
-                cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
-                bias_k, bias_v, self.add_zero_attn,
-                dropout, out_proj_weight, cast(Tensor, out_proj_bias),
-                training=self.training,
-                key_padding_mask=key_padding_mask, need_weights=need_weights,
-                attn_mask=attn_mask)
+            q_proj = k_proj = v_proj = None
+
+        return {
+            'in_proj_bias': in_proj_bias, 'in_proj_weight': in_proj_weight,
+            'bias_k': bias_k, 'bias_v': bias_v,
+            'out_proj.weight': out_proj_weight, 'out_proj.bias': out_proj_bias,
+            'q_proj_weight': q_proj, 'k_proj_weight': k_proj, 'v_proj_weight': v_proj,
+            'qkv_same_embed_dim': qkv_same_embed_dim
+        }
+
+    def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars):
+        kwargs = {name: self.forward_argument(name) for name in self.argument_list}
+        params_mapping = self.slice_param(**kwargs, mapping=True)
+
+        for name, value in itertools.chain(self._parameters.items(), self._buffers.items()):
+            if value is None or name in self._non_persistent_buffers_set:
+                continue
+            value = params_mapping.get(name, value)
+            destination[prefix + name] = value if keep_vars else value.detach()
+
+        # params of out_proj is handled in ``MixedMultiHeadAttention`` rather than
+        # ``NonDynamicallyQuantizableLinear`` sub-module. We also convert it to state dict here.
+        for name in ["out_proj.weight", "out_proj.bias"]:
+            value = params_mapping.get(name, None)
+            if value is None:
+                continue
+            destination[prefix + name] = value if keep_vars else value.detach()
+
+    def _save_module_to_state_dict(self, destination, prefix, keep_vars):
+        for name, module in self._modules.items():
+            # the weights of ``NonDynamicallyQuantizableLinear`` has been handled in `_save_param_buff_to_state_dict`.
+            if isinstance(module, nn.modules.linear.NonDynamicallyQuantizableLinear):
+                continue
+            if module is not None:
+                sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
+
+    def forward_with_args(
+        self, embed_dim: int_or_int_dict, num_heads: int,
+        kdim: int_or_int_dict | None, vdim: int_or_int_dict | None,
+        dropout: float,
+        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
+        key_padding_mask: torch.Tensor | None = None,
+        need_weights: bool = True, attn_mask: torch.Tensor | None = None
+    ) -> tuple[torch.Tensor, torch.Tensor | None]:
+
+        if any(isinstance(arg, dict) for arg in [num_heads, dropout]):
+            raise ValueError(_diff_not_compatible_error.format('num_heads and dropout', 'MultiHeadAttention'))
+
+        if getattr(self, 'batch_first', False):
+            # for backward compatibility: v1.7 doesn't have batch_first
+            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
+
+        if isinstance(embed_dim, dict):
+            used_embed_dim = self.embed_dim
+        else:
+            used_embed_dim = embed_dim
+
+        params_mapping = self.slice_param(embed_dim, kdim, vdim)
+        in_proj_bias, in_proj_weight, bias_k, bias_v, \
+            out_proj_weight, out_proj_bias, q_proj, k_proj, v_proj, qkv_same_embed_dim = [
+                params_mapping.get(name)
+                for name in ['in_proj_bias', 'in_proj_weight', 'bias_k', 'bias_v',
+                             'out_proj.weight', 'out_proj.bias', 'q_proj_weight', 'k_proj_weight',
+                             'v_proj_weight', 'qkv_same_embed_dim']
+            ]
+
+        # The rest part is basically same as pytorch
+        attn_output, attn_output_weights = F.multi_head_attention_forward(
+            query, key, value, used_embed_dim, num_heads,
+            cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
+            bias_k, bias_v, self.add_zero_attn,
+            dropout, out_proj_weight, cast(Tensor, out_proj_bias),
+            training=self.training,
+            key_padding_mask=key_padding_mask, need_weights=need_weights,
+            attn_mask=attn_mask, use_separate_proj_weight=not qkv_same_embed_dim,
+            q_proj_weight=q_proj, k_proj_weight=k_proj, v_proj_weight=v_proj)

         if getattr(self, 'batch_first', False):  # backward compatibility
             return attn_output.transpose(1, 0), attn_output_weights

@@ -701,6 +797,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
         return attn_output, attn_output_weights

+
 NATIVE_MIXED_OPERATIONS: list[Type[MixedOperation]] = [
     MixedLinear,
     MixedConv2d,
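What `slice_param` produces for `MixedLinear` is easy to reproduce with plain PyTorch: the supernet stores the largest candidate weight, and the exported state dict is just the leading rows and columns of it, under the keys the fixed layer expects. A self-contained illustration (the sizes are arbitrary examples, not values from the commit):

```python
import torch
import torch.nn as nn

supernet_linear = nn.Linear(9, 8)        # stores the largest candidates: in_features=9, out_features=8
arch = {'in_features': 3, 'out_features': 4}

# Mirrors MixedLinear.slice_param: take the leading rows/columns and keep the standard key names.
sliced = {
    'weight': supernet_linear.weight.detach()[:arch['out_features'], :arch['in_features']],
    'bias': supernet_linear.bias.detach()[:arch['out_features']],
}

subnet_linear = nn.Linear(arch['in_features'], arch['out_features'])
subnet_linear.load_state_dict(sliced)    # loads cleanly: keys and shapes match the fixed subnet

x = torch.randn(2, arch['in_features'])
print(torch.allclose(subnet_linear(x), x @ sliced['weight'].t() + sliced['bias']))  # True
```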
nni/nas/oneshot/pytorch/supermodule/sampling.py (view file @ 9e2a069d)

@@ -15,7 +15,7 @@ from nni.nas.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell
 from nni.nas.nn.pytorch.choice import ValueChoiceX
 from nni.nas.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs

-from .base import BaseSuperNetModule
+from .base import BaseSuperNetModule, sub_state_dict
 from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices, weighted_sum
 from .operation import MixedOperationSamplingPolicy, MixedOperation

@@ -76,6 +76,14 @@ class PathSamplingLayer(BaseSuperNetModule):
         """Override this to implement customized reduction."""
         return weighted_sum(items)

+    def _save_module_to_state_dict(self, destination, prefix, keep_vars):
+        sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
+
+        for samp in sampled:
+            module = getattr(self, str(samp))
+            if module is not None:
+                sub_state_dict(module, destination=destination, prefix=prefix, keep_vars=keep_vars)
+
     def forward(self, *args, **kwargs):
         if self._sampled is None:
             raise RuntimeError('At least one path needs to be sampled before fprop.')

@@ -229,7 +237,7 @@ class PathSamplingRepeat(BaseSuperNetModule):
     def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int]):
         super().__init__()
-        self.blocks = blocks
+        self.blocks: Any = blocks
         self.depth = depth
         self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
         self._sampled: list[int] | int | None = None

@@ -268,6 +276,15 @@ class PathSamplingRepeat(BaseSuperNetModule):
         """Override this to implement customized reduction."""
         return weighted_sum(items)

+    def _save_module_to_state_dict(self, destination, prefix, keep_vars):
+        sampled: Any = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
+
+        for cur_depth, (name, module) in enumerate(self.blocks.named_children(), start=1):
+            if module is not None:
+                sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
+            if not any(d > cur_depth for d in sampled):
+                break
+
     def forward(self, x):
         if self._sampled is None:
             raise RuntimeError('At least one depth needs to be sampled before fprop.')
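For `PathSamplingLayer` the interesting detail is the prefix: the supernet stores every candidate as a numbered child, but only the sampled candidate is exported, and it is exported without the candidate index, so the keys line up with a fixed subnet that kept just the chosen op. A small illustration of that key handling (the module names here are arbitrary; this is not NNI code):

```python
from collections import OrderedDict

import torch.nn as nn

# Supernet view: a layer choice with two numbered candidates.
candidates = nn.ModuleDict({'0': nn.Linear(4, 4), '1': nn.Conv1d(4, 4, 1)})
sampled = '0'                                   # pretend resample() picked candidate 0

destination = OrderedDict()
for name, value in candidates[sampled].state_dict().items():
    destination['op.' + name] = value           # prefix of the choice itself, no '.0.' segment

print(list(destination))                        # ['op.weight', 'op.bias']
```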
test/algo/nas/test_oneshot.py (view file @ 9e2a069d)

@@ -389,3 +389,22 @@ def test_optimizer_lr_scheduler():
     assert len(learning_rates) == 10 and abs(learning_rates[0] - 0.1) < 1e-5 and \
         abs(learning_rates[2] - 0.01) < 1e-5 and abs(learning_rates[-1] - 1e-5) < 1e-6
+
+
+def test_one_shot_sub_state_dict():
+    from nni.nas.strategy import RandomOneShot
+    from nni.nas import fixed_arch
+
+    init_kwargs = {}
+    x = torch.rand(1, 1, 28, 28)
+    for model_space_cls in [SimpleNet, ValueChoiceConvNet, RepeatNet]:
+        strategy = RandomOneShot()
+        model_space = model_space_cls()
+        strategy.attach_model(model_space)
+        arch = strategy.model.resample()
+        with fixed_arch(arch):
+            model = model_space_cls(**init_kwargs)
+        model.load_state_dict(strategy.sub_state_dict(arch))
+        model.eval()
+        model_space.eval()
+        assert torch.allclose(model(x), strategy.model(x))
test/algo/nas/test_oneshot_supermodules.py (view file @ 9e2a069d)

@@ -154,16 +154,28 @@ def test_differentiable_layerchoice_dedup():
     assert len(memo) == 1 and 'a' in memo


-def _mixed_operation_sampling_sanity_check(operation, memo, *input):
+def _mutate_op_path_sampling_policy(operation):
     for native_op in NATIVE_MIXED_OPERATIONS:
         if native_op.bound_type == type(operation):
             mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
             break
+    return mutate_op
+
+
+def _mixed_operation_sampling_sanity_check(operation, memo, *input):
+    mutate_op = _mutate_op_path_sampling_policy(operation)
     mutate_op.resample(memo=memo)
     return mutate_op(*input)


+from nni.nas.oneshot.pytorch.supermodule.base import sub_state_dict
+
+
+def _mixed_operation_state_dict_sanity_check(operation, model, memo, *input):
+    mutate_op = _mutate_op_path_sampling_policy(operation)
+    mutate_op.resample(memo=memo)
+    model.load_state_dict(sub_state_dict(mutate_op))
+    return mutate_op(*input), model(*input)
+
+
 def _mixed_operation_differentiable_sanity_check(operation, *input):
     for native_op in NATIVE_MIXED_OPERATIONS:
         if native_op.bound_type == type(operation):

@@ -188,6 +200,11 @@ def test_mixed_linear():
     linear = Linear(ValueChoice([3, 6, 9], label='shared'), ValueChoice([2, 4, 8]), bias=ValueChoice([False, True]))
     _mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))

+    linear = Linear(ValueChoice([3, 6, 9], label='in_features'), ValueChoice([2, 4, 8], label='out_features'), bias=True)
+    kwargs = {'in_features': 6, 'out_features': 4}
+    out1, out2 = _mixed_operation_state_dict_sanity_check(linear, Linear(**kwargs), kwargs, torch.randn(2, 6))
+    assert torch.allclose(out1, out2)
+

 def test_mixed_conv2d():
     conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([2, 4, 8], label='out') * 2, 1)

@@ -235,6 +252,17 @@ def test_mixed_conv2d():
     conv.resample({'k': 1})
     assert conv(torch.ones((1, 1, 3, 3))).sum().item() == 9

+    # only `in_channels`, `out_channels`, `kernel_size`, and `groups` influence state_dict
+    conv = Conv2d(
+        ValueChoice([2, 4, 8], label='in_channels'),
+        ValueChoice([6, 12, 24], label='out_channels'),
+        kernel_size=ValueChoice([3, 5, 7], label='kernel_size'),
+        groups=ValueChoice([1, 2], label='groups')
+    )
+    kwargs = {'in_channels': 8, 'out_channels': 12, 'kernel_size': 5, 'groups': 2}
+    out1, out2 = _mixed_operation_state_dict_sanity_check(conv, Conv2d(**kwargs), kwargs, torch.randn(2, 8, 16, 16))
+    assert torch.allclose(out1, out2)
+

 def test_mixed_batchnorm2d():
     bn = BatchNorm2d(ValueChoice([32, 64], label='dim'))

@@ -244,6 +272,10 @@ def test_mixed_batchnorm2d():
     _mixed_operation_differentiable_sanity_check(bn, torch.randn(2, 64, 3, 3))

+    bn = BatchNorm2d(ValueChoice([32, 48, 64], label='num_features'))
+    kwargs = {'num_features': 48}
+    out1, out2 = _mixed_operation_state_dict_sanity_check(bn, BatchNorm2d(**kwargs), kwargs, torch.randn(2, 48, 3, 3))
+    assert torch.allclose(out1, out2)
+

 def test_mixed_layernorm():
     ln = LayerNorm(ValueChoice([32, 64], label='normalized_shape'), elementwise_affine=True)

@@ -261,6 +293,10 @@ def test_mixed_layernorm():
     _mixed_operation_differentiable_sanity_check(ln, torch.randn(2, 64, 16))

+    ln = LayerNorm(ValueChoice([32, 48, 64], label='normalized_shape'))
+    kwargs = {'normalized_shape': 48}
+    out1, out2 = _mixed_operation_state_dict_sanity_check(ln, LayerNorm(**kwargs), kwargs, torch.randn(2, 8, 48))
+    assert torch.allclose(out1, out2)
+

 def test_mixed_mhattn():
     mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4)

@@ -293,6 +329,11 @@ def test_mixed_mhattn():
     _mixed_operation_differentiable_sanity_check(mhattn, torch.randn(5, 3, 8), torch.randn(5, 3, 8), torch.randn(5, 3, 8))

+    mhattn = MultiheadAttention(embed_dim=ValueChoice([4, 8, 16], label='embed_dim'), num_heads=ValueChoice([1, 2, 4], label='num_heads'),
+                                kdim=ValueChoice([4, 8, 16], label='kdim'), vdim=ValueChoice([4, 8, 16], label='vdim'))
+    kwargs = {'embed_dim': 16, 'num_heads': 2, 'kdim': 4, 'vdim': 8}
+    (out1, _), (out2, _) = _mixed_operation_state_dict_sanity_check(mhattn, MultiheadAttention(**kwargs), kwargs, torch.randn(7, 2, 16), torch.randn(7, 2, 4), torch.randn(7, 2, 8))
+    assert torch.allclose(out1, out2)
+

 @pytest.mark.skipif(torch.__version__.startswith('1.7'), reason='batch_first is not supported for legacy PyTorch')
 def test_mixed_mhattn_batch_first():