OpenDAS / nni · Commit 481aa292

Fix Autoformer to compatible with RandomOneShot strategy (#4987)

Authored Jul 14, 2022 by Maze; committed by GitHub, Jul 14, 2022. Parent: 5a3d82e8
Showing 6 changed files with 389 additions and 175 deletions:

    nni/retiarii/hub/pytorch/autoformer.py                          +279  −168
    nni/retiarii/hub/pytorch/utils/pretrained.py                    +5    −0
    nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py    +4    −4
    nni/retiarii/oneshot/pytorch/supermodule/operation.py           +70   −0
    test/algo/nas/test_oneshot_supermodules.py                      +24   −1
    test/algo/nas/test_space_hub_oneshot.py                         +7    −2
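Taken together, the six files move the Autoformer space onto constructs the one-shot engine can mutate: per-layer `ValueChoice`s instead of a `LayerChoice` over `(mlp_ratio, num_heads)` combinations, sliceable `ClsToken`/`AbsPosEmbed` basic units, a new `MixedLayerNorm` supermodule, and `Ellipsis` support in the slicing utilities. A usage sketch (not part of the diff; the import paths and the `RandomOneShot` entry point are assumptions based on nni 2.9-era retiarii APIs):

```python
# Usage sketch only. Import paths and the RandomOneShot entry point are
# assumptions based on this commit and nni 2.9-era retiarii, not verbatim from the diff.
from nni.retiarii.hub.pytorch import AutoformerSpace
import nni.retiarii.strategy as stg  # the test file below imports strategies as `stg`

space = AutoformerSpace(qkv_bias=True, drop_rate=0.0, drop_path_rate=0.1)
# The extra hooks let the one-shot engine mutate the new ClsToken / AbsPosEmbed
# basic units in addition to the native mixed operations (Linear, LayerNorm, ...).
strategy = stg.RandomOneShot(mutation_hooks=AutoformerSpace.get_extra_mutation_hooks())
```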
nni/retiarii/hub/pytorch/autoformer.py — view file @ 481aa292

```diff
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.
 
-import itertools
-from typing import Optional, Tuple, cast
+from typing import Optional, Tuple, cast, Any, Dict
 
 import torch
 import torch.nn.functional as F
 from timm.models.layers import trunc_normal_, DropPath
 
 import nni.retiarii.nn.pytorch as nn
-from nni.retiarii import model_wrapper
+from nni.retiarii import model_wrapper, basic_unit
+from nni.retiarii.nn.pytorch.api import ValueChoiceX
+from nni.retiarii.oneshot.pytorch.supermodule.operation import MixedOperation
+from nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils import traverse_all_options
+from nni.retiarii.oneshot.pytorch.supermodule._operation_utils import Slicable as _S, MaybeWeighted as _W
 
 from .utils.fixed import FixedFactory
 from .utils.pretrained import load_pretrained_weight
 
 class RelativePosition2D(nn.Module):
 ...
@@ -16,10 +23,8 @@ class RelativePosition2D(nn.Module):
         super().__init__()
         self.head_embed_dim = head_embed_dim
         self.legnth = length
-        self.embeddings_table_v = nn.Parameter(
-            torch.randn(length * 2 + 2, head_embed_dim))
-        self.embeddings_table_h = nn.Parameter(
-            torch.randn(length * 2 + 2, head_embed_dim))
+        self.embeddings_table_v = nn.Parameter(torch.randn(length * 2 + 2, head_embed_dim))
+        self.embeddings_table_h = nn.Parameter(torch.randn(length * 2 + 2, head_embed_dim))
 
         trunc_normal_(self.embeddings_table_v, std=.02)
         trunc_normal_(self.embeddings_table_h, std=.02)
 ...
@@ -28,48 +33,31 @@ class RelativePosition2D(nn.Module):
         # remove the first cls token distance computation
         length_q = length_q - 1
         length_k = length_k - 1
-        range_vec_q = torch.arange(length_q)
-        range_vec_k = torch.arange(length_k)
+        # init in the device directly, rather than move to device
+        range_vec_q = torch.arange(length_q, device=self.embeddings_table_v.device)
+        range_vec_k = torch.arange(length_k, device=self.embeddings_table_v.device)
         # compute the row and column distance
-        distance_mat_v = (range_vec_k[None, :] //
-                          int(length_q ** 0.5) -
-                          range_vec_q[:, None] //
-                          int(length_q ** 0.5))
-        distance_mat_h = (range_vec_k[None, :] % int(length_q ** 0.5) - range_vec_q[:, None] % int(length_q ** 0.5))
+        length_q_sqrt = int(length_q ** 0.5)
+        distance_mat_v = (range_vec_k[None, :] // length_q_sqrt - range_vec_q[:, None] // length_q_sqrt)
+        distance_mat_h = (range_vec_k[None, :] % length_q_sqrt - range_vec_q[:, None] % length_q_sqrt)
         # clip the distance to the range of [-legnth, legnth]
         distance_mat_clipped_v = torch.clamp(
             distance_mat_v, -self.legnth, self.legnth)
         distance_mat_clipped_h = torch.clamp(distance_mat_h, -self.legnth, self.legnth)
 
-        # translate the distance from [1, 2 * legnth + 1], 0 is for the cls
-        # token
+        # translate the distance from [1, 2 * legnth + 1], 0 is for the cls token
         final_mat_v = distance_mat_clipped_v + self.legnth + 1
         final_mat_h = distance_mat_clipped_h + self.legnth + 1
         # pad the 0 which represent the cls token
         final_mat_v = F.pad(
             final_mat_v, (1, 0, 1, 0), "constant", 0)
         final_mat_h = F.pad(final_mat_h, (1, 0, 1, 0), "constant", 0)
-        final_mat_v = torch.tensor(final_mat_v, dtype=torch.long, device=self.embeddings_table_v.device)
-        final_mat_h = torch.tensor(final_mat_h, dtype=torch.long, device=self.embeddings_table_v.device)
+        final_mat_v = final_mat_v.long()
+        final_mat_h = final_mat_h.long()
        # get the embeddings with the corresponding distance
-        embeddings = self.embeddings_table_v[final_mat_v] + \
-            self.embeddings_table_h[final_mat_h]
+        embeddings = self.embeddings_table_v[final_mat_v] + self.embeddings_table_h[final_mat_h]
         return embeddings
```
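The forward pass builds a 2-D relative-position index table: row/column distances are clamped to `[-legnth, legnth]`, shifted into `[1, 2 * legnth + 1]`, and index 0 is reserved for the cls token by padding. A toy check of that arithmetic (illustrative only; the sizes are invented, not from the diff):

```python
import torch
import torch.nn.functional as F

length = 2                       # toy table half-width, not a value from the diff
length_q = 5                     # 4 patches (a 2x2 grid) + 1 cls token
rv = torch.arange(length_q - 1)  # 0..3; the cls token is handled by padding below
sqrt = int((length_q - 1) ** 0.5)
dist_v = rv[None, :] // sqrt - rv[:, None] // sqrt          # row distance in {-1, 0, 1}
idx_v = torch.clamp(dist_v, -length, length) + length + 1   # shift into [1, 2*length+1]
idx_v = F.pad(idx_v, (1, 0, 1, 0), "constant", 0).long()    # row/col 0 reserved for cls
assert idx_v.shape == (length_q, length_q) and idx_v.min() == 0
```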
```diff
 class RelativePositionAttention(nn.Module):
     """
     This class is designed to support the relative position in attention.
 ...
@@ -80,61 +68,62 @@ class RelativePositionAttention(nn.Module):
     and keys in self-attention modules.
     """
     def __init__(
         self,
-        embed_dim,
-        fixed_embed_dim,
-        num_heads,
-        attn_drop=0.,
-        proj_drop=0,
-        rpe=False,
-        qkv_bias=False,
-        qk_scale=None,
-        rpe_length=14
-    ) -> None:
+        embed_dim, num_heads,
+        attn_drop=0., proj_drop=0.,
+        qkv_bias=False, qk_scale=None,
+        rpe_length=14, rpe=False,
+        head_dim=64):
         super().__init__()
         self.num_heads = num_heads
-        head_dim = embed_dim // num_heads
+        # head_dim is fixed 64 in official autoformer. set head_dim = None to use flex head dim.
+        self.head_dim = head_dim or (embed_dim // num_heads)
         self.scale = qk_scale or head_dim ** -0.5
-        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=qkv_bias)
+        # Please refer to MixedMultiheadAttention for details.
+        self.q = nn.Linear(embed_dim, head_dim * num_heads, bias=qkv_bias)
+        self.k = nn.Linear(embed_dim, head_dim * num_heads, bias=qkv_bias)
+        self.v = nn.Linear(embed_dim, head_dim * num_heads, bias=qkv_bias)
         self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(embed_dim, embed_dim)
+        self.proj = nn.Linear(head_dim * num_heads, embed_dim)
         self.proj_drop = nn.Dropout(proj_drop)
         self.rpe = rpe
         if rpe:
-            self.rel_pos_embed_k = RelativePosition2D(
-                fixed_embed_dim // num_heads, rpe_length)
-            self.rel_pos_embed_v = RelativePosition2D(
-                fixed_embed_dim // num_heads, rpe_length)
+            self.rel_pos_embed_k = RelativePosition2D(
+                head_dim, rpe_length)
+            self.rel_pos_embed_v = RelativePosition2D(head_dim, rpe_length)
 
     def forward(self, x):
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(
-            2, 0, 3, 1, 4)
-        # make torchscript happy (cannot use tensor as tuple)
-        q, k, v = qkv[0], qkv[1], qkv[2]
+        B, N, _ = x.shape
+        head_dim = self.head_dim
+        # num_heads can not get from self.num_heads directly,
+        # use -1 to compute implicitly.
+        num_heads = -1
+        q = self.q(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
+        k = self.k(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
+        v = self.v(x).reshape(B, N, num_heads, head_dim).permute(0, 2, 1, 3)
+        num_heads = q.size(1)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
         if self.rpe:
             r_p_k = self.rel_pos_embed_k(N, N)
-            attn = attn + (q.permute(2, 0, 1, 3).reshape(
-                N, self.num_heads * B, -1) @ r_p_k.transpose(2, 1)) \
-                .transpose(1, 0).reshape(B, self.num_heads, N, N) * self.scale
+            attn = attn + (q.permute(2, 0, 1, 3).reshape(
+                N, num_heads * B, head_dim) @ r_p_k.transpose(2, 1)) \
+                .transpose(1, 0).reshape(B, num_heads, N, N) * self.scale
 
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = (attn @ v).transpose(1, 2).reshape(B, N, num_heads * head_dim)
         if self.rpe:
-            r_p_v = self.rel_pos_embed_v(N, N)
-            attn_1 = attn.permute(2, 0, 1, 3).reshape(N, B * self.num_heads, -1)
+            attn_1 = attn.permute(2, 0, 1, 3).reshape(N, B * num_heads, N)
+            r_p_v = self.rel_pos_embed_v(N, N)
             # The size of attention is (B, num_heads, N, N), reshape it to (N, B*num_heads, N) and do batch matmul with
             # the relative position embedding of V (N, N, head_dim) get shape like (N, B*num_heads, head_dim). We reshape it to the
             # same size as x (B, num_heads, N, hidden_dim)
-            x = x + (attn_1 @ r_p_v).transpose(1, 0).reshape(B,
-                self.num_heads, N, -1).transpose(2, 1).reshape(B, N, -1)
+            x = x + (attn_1 @ r_p_v).transpose(1, 0).reshape(B,
+                num_heads, N, head_dim).transpose(2, 1).reshape(B, N, num_heads * head_dim)
 
         x = self.proj(x)
         x = self.proj_drop(x)
         return x
```
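The `-1` in the reshape is the heart of this change: in the supernet, `self.num_heads` is a per-layer `ValueChoice` rather than an `int`, so the head count must be inferred from the data while `head_dim` stays a fixed 64. A standalone illustration (plain PyTorch, example sizes only):

```python
import torch

B, N, embed_dim, head_dim = 2, 5, 192, 64   # example sizes, not from the diff
q_flat = torch.randn(B, N, embed_dim)       # stand-in for the output of a sliced q projection
# The head count is not known statically, so let reshape infer it from the data:
q = q_flat.reshape(B, N, -1, head_dim).permute(0, 2, 1, 3)
num_heads = q.size(1)
assert num_heads == embed_dim // head_dim   # == 3 here
```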
```diff
 ...
@@ -146,61 +135,60 @@ class TransformerEncoderLayer(nn.Module):
     The pytorch build-in nn.TransformerEncoderLayer() does not support customed attention.
     """
     def __init__(
         self,
-        embed_dim,
-        fixed_embed_dim,
-        num_heads,
-        mlp_ratio=4.,
-        qkv_bias=False,
-        qk_scale=None,
-        rpe=False,
-        drop_rate=0.,
-        attn_drop=0.,
-        proj_drop=0.,
-        drop_path=0.,
-        pre_norm=True,
-        rpe_length=14,
+        embed_dim, num_heads, mlp_ratio=4.,
+        qkv_bias=False, qk_scale=None, rpe=False,
+        drop_rate=0., attn_drop=0., proj_drop=0., drop_path=0.,
+        pre_norm=True, rpe_length=14, head_dim=64
     ):
         super().__init__()
 
         self.normalize_before = pre_norm
         self.drop_path = DropPath(
             drop_path) if drop_path > 0. else nn.Identity()
         self.dropout = drop_rate
         self.attn = RelativePositionAttention(
             embed_dim=embed_dim,
-            fixed_embed_dim=fixed_embed_dim,
             num_heads=num_heads,
             attn_drop=attn_drop,
             proj_drop=proj_drop,
             rpe=rpe,
             qkv_bias=qkv_bias,
             qk_scale=qk_scale,
-            rpe_length=rpe_length)
+            rpe_length=rpe_length, head_dim=head_dim)
 
         self.attn_layer_norm = nn.LayerNorm(embed_dim)
         self.ffn_layer_norm = nn.LayerNorm(embed_dim)
         self.activation_fn = nn.GELU()
         self.fc1 = nn.Linear(
             cast(int, embed_dim),
             cast(int, nn.ValueChoice.to_int(embed_dim * mlp_ratio)))
         self.fc2 = nn.Linear(
             cast(int, nn.ValueChoice.to_int(embed_dim * mlp_ratio)),
             cast(int, embed_dim))
 
+    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
+        assert before ^ after
+        if after ^ self.normalize_before:
+            return layer_norm(x)
+        else:
+            return x
+
     def forward(self, x):
         """
         Args:
             x (Tensor): input to the layer of shape `(batch, patch_num , sample_embed_dim)`
         Returns:
             encoded output of shape `(batch, patch_num, sample_embed_dim)`
         """
         residual = x
         x = self.maybe_layer_norm(self.attn_layer_norm, x, before=True)
         x = self.attn(x)
-        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        x = F.dropout(x, p=self.dropout, training=self.training)
         x = self.drop_path(x)
         x = residual + x
         x = self.maybe_layer_norm(self.attn_layer_norm, x, after=True)
```
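`maybe_layer_norm` packs the pre-norm/post-norm switch into one XOR: with `normalize_before=True` the `before=True` call applies the norm and the `after=True` call is an identity, and the other way around for post-norm. A small sanity check of that decision logic (illustrative only, not part of the diff):

```python
def applies_norm(normalize_before: bool, before: bool = False, after: bool = False) -> bool:
    """Mirrors maybe_layer_norm above: True when layer_norm(x) would be returned."""
    assert before ^ after
    return bool(after ^ normalize_before)

assert applies_norm(True, before=True) and not applies_norm(True, after=True)    # pre-norm
assert applies_norm(False, after=True) and not applies_norm(False, before=True)  # post-norm
```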
```diff
 ...
@@ -209,20 +197,86 @@ class TransformerEncoderLayer(nn.Module):
         x = self.maybe_layer_norm(self.ffn_layer_norm, x, before=True)
         x = self.fc1(x)
         x = self.activation_fn(x)
-        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        x = F.dropout(x, p=self.dropout, training=self.training)
         x = self.fc2(x)
-        x = nn.functional.dropout(x, p=self.dropout, training=self.training)
+        x = F.dropout(x, p=self.dropout, training=self.training)
         x = self.drop_path(x)
         x = residual + x
         x = self.maybe_layer_norm(self.ffn_layer_norm, x, after=True)
         return x
 
-    def maybe_layer_norm(self, layer_norm, x, before=False, after=False):
-        assert before ^ after
-        if after ^ self.normalize_before:
-            return layer_norm(x)
-        else:
-            return x
+
+@basic_unit
+class ClsToken(nn.Module):
+    """ Concat class token with dim=embed_dim before patch embedding.
+    """
+    def __init__(self, embed_dim: int):
+        super().__init__()
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        trunc_normal_(self.cls_token, std=.02)
+
+    def forward(self, x):
+        return torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+
+
+class MixedClsToken(MixedOperation, ClsToken):
+    """ Mixed class token concat operation.
+
+    Supported arguments are:
+
+    - ``embed_dim``
+
+    Prefix of cls_token will be sliced.
+    """
+    bound_type = ClsToken
+    argument_list = ['embed_dim']
+
+    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
+        return max(traverse_all_options(value_choice))
+
+    def forward_with_args(self, embed_dim, inputs: torch.Tensor) -> torch.Tensor:
+        embed_dim_ = _W(embed_dim)
+
+        cls_token = _S(self.cls_token)[..., :embed_dim_]
+
+        return torch.cat((cls_token.expand(inputs.shape[0], -1, -1), inputs), dim=1)
+
+
+@basic_unit
+class AbsPosEmbed(nn.Module):
+    """ Add absolute position embedding on patch embedding.
+    """
+    def __init__(self, length: int, embed_dim: int):
+        super().__init__()
+        self.pos_embed = nn.Parameter(torch.zeros(1, length, embed_dim))
+        trunc_normal_(self.pos_embed, std=.02)
+
+    def forward(self, x):
+        return x + self.pos_embed
+
+
+class MixedAbsPosEmbed(MixedOperation, AbsPosEmbed):
+    """ Mixed absolute position embedding add operation.
+
+    Supported arguments are:
+
+    - ``embed_dim``
+
+    Prefix of pos_embed will be sliced.
+    """
+    bound_type = AbsPosEmbed
+    argument_list = ['embed_dim']
+
+    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
+        return max(traverse_all_options(value_choice))
+
+    def forward_with_args(self, embed_dim, inputs: torch.Tensor) -> torch.Tensor:
+        embed_dim_ = _W(embed_dim)
+
+        pos_embed = _S(self.pos_embed)[..., :embed_dim_]
+
+        return inputs + pos_embed
```
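Both `Mixed*` classes follow the `MixedOperation` protocol: `super_init_argument` sizes the parameter at the largest `embed_dim` in the choice, and `forward_with_args` slices a prefix per sample via `_S`/`_W`. For a concrete (non-weighted) sample the slicing reduces to something like this (illustrative plain PyTorch; sizes taken from the space defaults):

```python
import torch

max_embed_dim, sampled = 240, 192             # example choices from the space defaults
cls_token = torch.zeros(1, 1, max_embed_dim)  # allocated once at the maximum size
# For a concrete sample, _S(cls_token)[..., :sampled] reduces to a plain prefix slice;
# _W(embed_dim) additionally allows the bound to be a weighted mixture for
# differentiable strategies.
sliced = cls_token[..., :sampled]
assert sliced.shape == (1, 1, 192)
```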
```diff
 @model_wrapper
 ...
@@ -267,87 +321,144 @@ class AutoformerSpace(nn.Module):
         The scaler on score map in self-attention.
     rpe : bool
         Whether to use relative position encoding.
     """
     def __init__(
         self,
         search_embed_dim: Tuple[int, ...] = (192, 216, 240),
-        search_mlp_ratio: Tuple[float, ...] = (3.5, 4.0),
+        search_mlp_ratio: Tuple[float, ...] = (3.0, 3.5, 4.0),
         search_num_heads: Tuple[int, ...] = (3, 4),
         search_depth: Tuple[int, ...] = (12, 13, 14),
         img_size: int = 224,
         patch_size: int = 16,
         in_chans: int = 3,
         num_classes: int = 1000,
         qkv_bias: bool = False,
         drop_rate: float = 0.,
         attn_drop_rate: float = 0.,
         drop_path_rate: float = 0.,
         pre_norm: bool = True,
         global_pool: bool = False,
         abs_pos: bool = True,
         qk_scale: Optional[float] = None,
         rpe: bool = True,
     ):
         super().__init__()
 
+        # define search space parameters
         embed_dim = nn.ValueChoice(list(search_embed_dim), label="embed_dim")
-        fixed_embed_dim = nn.ModelParameterChoice(list(search_embed_dim), label="embed_dim")
         depth = nn.ValueChoice(list(search_depth), label="depth")
+        mlp_ratios = [nn.ValueChoice(list(search_mlp_ratio), label=f"mlp_ratio_{i}") for i in range(max(search_depth))]
+        num_heads = [nn.ValueChoice(list(search_num_heads), label=f"num_head_{i}") for i in range(max(search_depth))]
 
         self.patch_embed = nn.Conv2d(
             in_chans, cast(int, embed_dim),
             kernel_size=patch_size,
             stride=patch_size
         )
         self.patches_num = int((img_size // patch_size) ** 2)
         self.global_pool = global_pool
 
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, cast(int, fixed_embed_dim)))
-        trunc_normal_(self.cls_token, std=.02)
-
-        self.abs_pos = abs_pos
-        if self.abs_pos:
-            self.pos_embed = nn.Parameter(torch.zeros(
-                1, self.patches_num + 1, cast(int, fixed_embed_dim)))
-            trunc_normal_(self.pos_embed, std=.02)
-
-        dpr = [
-            x.item() for x in torch.linspace(
-                0, drop_path_rate,
-                max(search_depth))]  # stochastic depth decay rule
-
-        self.blocks = nn.Repeat(lambda index: nn.LayerChoice([
-            TransformerEncoderLayer(embed_dim=embed_dim,
-                                    fixed_embed_dim=fixed_embed_dim,
-                                    num_heads=num_heads,
-                                    mlp_ratio=mlp_ratio,
-                                    qkv_bias=qkv_bias,
-                                    drop_rate=drop_rate,
-                                    attn_drop=attn_drop_rate,
-                                    drop_path=dpr[index],
-                                    rpe_length=img_size // patch_size,
-                                    qk_scale=qk_scale,
-                                    rpe=rpe,
-                                    pre_norm=pre_norm,)
-            for mlp_ratio, num_heads in itertools.product(search_mlp_ratio, search_num_heads)
-        ], label=f'layer{index}'), depth)
-
-        self.pre_norm = pre_norm
-        if self.pre_norm:
-            self.norm = nn.LayerNorm(cast(int, embed_dim))
-        self.head = nn.Linear(
-            cast(int, embed_dim),
-            num_classes) if num_classes > 0 else nn.Identity()
+        self.cls_token = ClsToken(cast(int, embed_dim))
+        self.pos_embed = AbsPosEmbed(self.patches_num + 1, cast(int, embed_dim)) if abs_pos else nn.Identity()
+
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate,
+               max(search_depth))]  # stochastic depth decay rule
+
+        self.blocks = nn.Repeat(
+            lambda index: TransformerEncoderLayer(
+                embed_dim=embed_dim, num_heads=num_heads[index], mlp_ratio=mlp_ratios[index],
+                qkv_bias=qkv_bias, drop_rate=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[index],
+                rpe_length=img_size // patch_size, qk_scale=qk_scale, rpe=rpe, pre_norm=pre_norm, head_dim=64
+            ), depth)
+
+        self.norm = nn.LayerNorm(cast(int, embed_dim)) if pre_norm else nn.Identity()
+        self.head = nn.Linear(cast(int, embed_dim), num_classes) if num_classes > 0 else nn.Identity()
+
+    @classmethod
+    def get_extra_mutation_hooks(cls):
+        return [MixedAbsPosEmbed.mutate, MixedClsToken.mutate]
+
+    @classmethod
+    def load_searched_model(
+        cls, name: str,
+        pretrained: bool = False, download: bool = False, progress: bool = True
+    ) -> nn.Module:
+
+        init_kwargs = {'qkv_bias': True, 'drop_rate': 0.0, 'drop_path_rate': 0.1, 'global_pool': True, 'num_classes': 1000}
+
+        if name == 'autoformer-tiny':
+            mlp_ratio = [3.5, 3.5, 3.0, 3.5, 3.0, 3.0, 4.0, 4.0, 3.5, 4.0, 3.5, 4.0, 3.5] + [3.0]
+            num_head = [3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3] + [3]
+            arch: Dict[str, Any] = {'embed_dim': 192, 'depth': 13}
+            for i in range(14):
+                arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
+                arch[f'num_head_{i}'] = num_head[i]
+            init_kwargs.update({
+                'search_embed_dim': (240, 216, 192),
+                'search_mlp_ratio': (4.0, 3.5, 3.0),
+                'search_num_heads': (4, 3),
+                'search_depth': (14, 13, 12),
+            })
+        elif name == 'autoformer-small':
+            mlp_ratio = [3.0, 3.5, 3.0, 3.5, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 3.5, 4.0] + [3.0]
+            num_head = [6, 6, 5, 7, 5, 5, 5, 6, 6, 7, 7, 6, 7] + [5]
+            arch: Dict[str, Any] = {'embed_dim': 384, 'depth': 13}
+            for i in range(14):
+                arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
+                arch[f'num_head_{i}'] = num_head[i]
+            init_kwargs.update({
+                'search_embed_dim': (448, 384, 320),
+                'search_mlp_ratio': (4.0, 3.5, 3.0),
+                'search_num_heads': (7, 6, 5),
+                'search_depth': (14, 13, 12),
+            })
+        elif name == 'autoformer-base':
+            mlp_ratio = [3.5, 3.5, 4.0, 3.5, 4.0, 3.5, 3.5, 3.0, 4.0, 4.0, 3.0, 4.0, 3.0, 3.5] + [3.0, 3.0]
+            num_head = [9, 9, 9, 9, 9, 10, 9, 9, 10, 9, 10, 9, 9, 10] + [8, 8]
+            arch: Dict[str, Any] = {'embed_dim': 576, 'depth': 14}
+            for i in range(16):
+                arch[f'mlp_ratio_{i}'] = mlp_ratio[i]
+                arch[f'num_head_{i}'] = num_head[i]
+            init_kwargs.update({
+                'search_embed_dim': (624, 576, 528),
+                'search_mlp_ratio': (4.0, 3.5, 3.0),
+                'search_num_heads': (10, 9, 8),
+                'search_depth': (16, 15, 14),
+            })
+        else:
+            raise ValueError(f'Unsupported architecture with name: {name}')
+
+        model_factory = FixedFactory(cls, arch)
+        model = model_factory(**init_kwargs)
+
+        if pretrained:
+            weight_file = load_pretrained_weight(name, download=download, progress=progress)
+            pretrained_weights = torch.load(weight_file)
+            model.load_state_dict(pretrained_weights)
+
+        return model
```
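`load_searched_model` maps a released checkpoint name onto a fixed subnet via `FixedFactory`. The intended call looks like this (the name and flags come from this diff; the import path is assumed):

```python
from nni.retiarii.hub.pytorch import AutoformerSpace  # assumed public import path

# Builds the fixed 'autoformer-tiny' subnet and, with pretrained=True, fetches the
# 'autoformer-searched-tiny-*.pth' weights registered in pretrained.py below.
model = AutoformerSpace.load_searched_model('autoformer-tiny', pretrained=True, download=True)
```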
```diff
     def forward(self, x):
         B = x.shape[0]
         x = self.patch_embed(x)
         x = x.permute(0, 2, 3, 1).view(B, self.patches_num, -1)
-        cls_tokens = self.cls_token.expand(B, -1, -1)
-        x = torch.cat((cls_tokens, x), dim=1)
-        if self.abs_pos:
-            x = x + self.pos_embed
+        x = self.cls_token(x)
+        x = self.pos_embed(x)
 
         x = self.blocks(x)
-        if self.pre_norm:
-            x = self.norm(x)
+        x = self.norm(x)
 
         if self.global_pool:
             x = torch.mean(x[:, 1:], dim=1)
         else:
 ...
```
nni/retiarii/hub/pytorch/utils/pretrained.py — view file @ 481aa292

```diff
 ...
@@ -37,6 +37,11 @@ PRETRAINED_WEIGHT_URLS = {
     # spos
     'spos': f'{NNI_BLOB}/nashub/spos-0b17f6fc.pth',
+
+    # autoformer
+    'autoformer-tiny': f'{NNI_BLOB}/nashub/autoformer-searched-tiny-1e90ebc1.pth',
+    'autoformer-small': f'{NNI_BLOB}/nashub/autoformer-searched-small-4bc5d4e5.pth',
+    'autoformer-base': f'{NNI_BLOB}/nashub/autoformer-searched-base-c417590a.pth'
 }
 ...
```
nni/retiarii/oneshot/pytorch/supermodule/_operation_utils.py — view file @ 481aa292

```diff
 ...
@@ -140,7 +140,7 @@ class Slicable(Generic[T]):
             raise TypeError(f'Unsuppoted weight type: {type(weight)}')
         self.weight = weight
 
-    def __getitem__(self, index: slice_type | multidim_slice) -> T:
+    def __getitem__(self, index: slice_type | multidim_slice | Any) -> T:
         if not isinstance(index, tuple):
             index = (index, )
         index = cast(multidim_slice, index)
 ...
@@ -267,7 +267,7 @@ def _iterate_over_slice_type(s: slice_type):
 def _iterate_over_multidim_slice(ms: multidim_slice):
     """Get :class:`MaybeWeighted` instances in ``ms``."""
     for s in ms:
-        if s is not None:
+        if s is not None and s is not Ellipsis:
             yield from _iterate_over_slice_type(s)
 ...
@@ -286,8 +286,8 @@ def _evaluate_multidim_slice(ms: multidim_slice, value_fn: _value_fn_type = None
     """Wraps :meth:`MaybeWeighted.evaluate` to evaluate the whole ``multidim_slice``."""
     res = []
     for s in ms:
-        if s is not None:
+        if s is not None and s is not Ellipsis:
             res.append(_evaluate_slice_type(s, value_fn))
         else:
-            res.append(None)
+            res.append(s)
     return tuple(res)
```
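With the `is not Ellipsis` guards, a literal `...` now passes through slice iteration and evaluation untouched, so `Slicable` indexing such as `weight[..., :dim]` works; the new tests below exercise exactly this. Why pass-through is the right behavior (plain-PyTorch illustration, toy shapes):

```python
import torch

w = torch.randn(3, 7, 24, 24)
# An evaluated multidim slice such as (Ellipsis, slice(None, 4)) must keep the
# Ellipsis literal so tensor indexing can expand it over the middle dimensions.
assert w[(Ellipsis, slice(None, 4))].shape == (3, 7, 24, 4)
```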
nni/retiarii/oneshot/pytorch/supermodule/operation.py — view file @ 481aa292

```diff
 ...
@@ -35,6 +35,7 @@ __all__ = [
     'MixedLinear',
     'MixedConv2d',
     'MixedBatchNorm2d',
+    'MixedLayerNorm',
     'MixedMultiHeadAttention',
     'NATIVE_MIXED_OPERATIONS',
 ]
 ...
@@ -472,6 +473,74 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
             eps,
         )
 
+
+class MixedLayerNorm(MixedOperation, nn.LayerNorm):
+    """
+    Mixed LayerNorm operation.
+
+    Supported arguments are:
+
+    - ``normalized_shape``
+    - ``eps`` (only supported in path sampling)
+
+    For path-sampling, prefix of ``weight`` and ``bias`` are sliced.
+    For weighted cases, the maximum ``normalized_shape`` is used directly.
+
+    eps is required to be float.
+    """
+
+    bound_type = retiarii_nn.LayerNorm
+    argument_list = ['normalized_shape', 'eps']
+
+    @staticmethod
+    def _to_tuple(value: scalar_or_scalar_dict[Any]) -> tuple[Any, Any]:
+        if not isinstance(value, tuple):
+            return (value, value)
+        return value
+
+    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
+        if name not in ['normalized_shape']:
+            raise NotImplementedError(f'Unsupported value choice on argument: {name}')
+        all_sizes = set(traverse_all_options(value_choice))
+        if any(isinstance(sz, (tuple, list)) for sz in all_sizes):
+            # transpose
+            all_sizes = list(zip(*all_sizes))
+            # maximum dim should be calculated on every dimension
+            return (max(self._to_tuple(sz)) for sz in all_sizes)
+        else:
+            return max(all_sizes)
+
+    def forward_with_args(self,
+                          normalized_shape, eps: float,
+                          inputs: torch.Tensor) -> torch.Tensor:
+        if any(isinstance(arg, dict) for arg in [eps]):
+            raise ValueError(_diff_not_compatible_error.format('eps', 'LayerNorm'))
+
+        if isinstance(normalized_shape, dict):
+            normalized_shape = self.normalized_shape
+
+        # make it as tuple
+        if isinstance(normalized_shape, int):
+            normalized_shape = (normalized_shape, )
+        if isinstance(self.normalized_shape, int):
+            normalized_shape = (self.normalized_shape, )
+
+        # slice all the normalized shape
+        indices = [slice(0, min(i, j)) for i, j in zip(normalized_shape, self.normalized_shape)]
+
+        # remove _S(*)
+        weight = self.weight[indices] if self.weight is not None else None
+        bias = self.bias[indices] if self.bias is not None else None
+
+        return F.layer_norm(inputs, normalized_shape, weight, bias, eps)
```
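`MixedLayerNorm` mirrors the other native mixed operations: affine parameters are allocated at the maximum `normalized_shape` and a per-sample prefix is sliced before calling `F.layer_norm`. A rough plain-PyTorch equivalent for one sampled shape (a sketch, not the nni API; sizes echo the new test below):

```python
import torch
import torch.nn.functional as F

max_shape, sampled = (64,), (32,)      # example choices, cf. the new test_mixed_layernorm
weight = torch.ones(max_shape)         # supernet-sized affine parameters
bias = torch.zeros(max_shape)
indices = tuple(slice(0, min(i, j)) for i, j in zip(sampled, max_shape))
x = torch.randn(2, 16, 32)
out = F.layer_norm(x, sampled, weight[indices], bias[indices], 1e-5)
assert out.shape == (2, 16, 32)
```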
```diff
 class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
     """
 ...
@@ -628,6 +697,7 @@ NATIVE_MIXED_OPERATIONS: list[Type[MixedOperation]] = [
     MixedLinear,
     MixedConv2d,
     MixedBatchNorm2d,
+    MixedLayerNorm,
     MixedMultiHeadAttention,
 ]
 ...
```
test/algo/nas/test_oneshot_supermodules.py — view file @ 481aa292

```diff
 ...
@@ -3,7 +3,7 @@ import pytest
 import numpy as np
 import torch
 import torch.nn as nn
-from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, Linear, MultiheadAttention
+from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, LayerNorm, Linear, MultiheadAttention
 from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules
 from nni.retiarii.oneshot.pytorch.supermodule.differentiable import (
     MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax,
 ...
@@ -28,6 +28,12 @@ def test_slice():
     assert S(weight)[:, 1:W(3)*2+1, :, 9:13].shape == (3, 6, 24, 4)
     assert S(weight)[:, 1:W(3)*2+1].shape == (3, 6, 24, 23)
 
+    # Ellipsis
+    assert S(weight)[..., 9:13].shape == (3, 7, 24, 4)
+    assert S(weight)[:2, ..., 1:W(3)+1].shape == (2, 7, 24, 3)
+    assert S(weight)[..., 1:W(3)*2+1].shape == (3, 7, 24, 6)
+    assert S(weight)[..., :10, 1:W(3)*2+1].shape == (3, 7, 10, 6)
+
     # no effect
     assert S(weight)[:] is weight
 ...
@@ -227,6 +233,23 @@ def test_mixed_batchnorm2d():
     _mixed_operation_differentiable_sanity_check(bn, torch.randn(2, 64, 3, 3))
 
+
+def test_mixed_layernorm():
+    ln = LayerNorm(ValueChoice([32, 64], label='normalized_shape'), elementwise_affine=True)
+
+    assert _mixed_operation_sampling_sanity_check(ln, {'normalized_shape': 32}, torch.randn(2, 16, 32)).size(-1) == 32
+    assert _mixed_operation_sampling_sanity_check(ln, {'normalized_shape': 64}, torch.randn(2, 16, 64)).size(-1) == 64
+
+    _mixed_operation_differentiable_sanity_check(ln, torch.randn(2, 16, 64))
+
+    import itertools
+    ln = LayerNorm(ValueChoice(list(itertools.product([16, 32, 64], [8, 16])), label='normalized_shape'))
+
+    assert list(_mixed_operation_sampling_sanity_check(ln, {'normalized_shape': (16, 8)}, torch.randn(2, 16, 8)).shape[-2:]) == [16, 8]
+    assert list(_mixed_operation_sampling_sanity_check(ln, {'normalized_shape': (64, 16)}, torch.randn(2, 64, 16)).shape[-2:]) == [64, 16]
+
+    _mixed_operation_differentiable_sanity_check(ln, torch.randn(2, 64, 16))
+
+
 def test_mixed_mhattn():
     mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4)
 ...
```
...
test/algo/nas/test_space_hub_oneshot.py
View file @
481aa292
...
@@ -78,6 +78,11 @@ def _strategy_factory(alias, space_type):
...
@@ -78,6 +78,11 @@ def _strategy_factory(alias, space_type):
extra_mutation_hooks
.
append
(
NDSStagePathSampling
.
mutate
)
extra_mutation_hooks
.
append
(
NDSStagePathSampling
.
mutate
)
else
:
else
:
extra_mutation_hooks
.
append
(
NDSStageDifferentiable
.
mutate
)
extra_mutation_hooks
.
append
(
NDSStageDifferentiable
.
mutate
)
# Autoformer search space require specific extra hooks
if
space_type
==
'autoformer'
:
from
nni.retiarii.hub.pytorch.autoformer
import
MixedAbsPosEmbed
,
MixedClsToken
extra_mutation_hooks
.
extend
([
MixedAbsPosEmbed
.
mutate
,
MixedClsToken
.
mutate
])
if
alias
==
'darts'
:
if
alias
==
'darts'
:
return
stg
.
DARTS
(
mutation_hooks
=
extra_mutation_hooks
)
return
stg
.
DARTS
(
mutation_hooks
=
extra_mutation_hooks
)
...
@@ -149,7 +154,7 @@ def _dataset_factory(dataset_type, subset=20):
...
@@ -149,7 +154,7 @@ def _dataset_factory(dataset_type, subset=20):
'mobilenetv3_small'
,
'mobilenetv3_small'
,
'proxylessnas'
,
'proxylessnas'
,
'shufflenet'
,
'shufflenet'
,
#
'autoformer',
'autoformer'
,
'nasnet'
,
'nasnet'
,
'enas'
,
'enas'
,
'amoeba'
,
'amoeba'
,
...
@@ -186,7 +191,7 @@ def test_hub_oneshot(space_type, strategy_type):
...
@@ -186,7 +191,7 @@ def test_hub_oneshot(space_type, strategy_type):
NDS_SPACES
=
[
'amoeba'
,
'darts'
,
'pnas'
,
'enas'
,
'nasnet'
]
NDS_SPACES
=
[
'amoeba'
,
'darts'
,
'pnas'
,
'enas'
,
'nasnet'
]
if
strategy_type
==
'proxyless'
:
if
strategy_type
==
'proxyless'
:
if
'width'
in
space_type
or
'depth'
in
space_type
or
\
if
'width'
in
space_type
or
'depth'
in
space_type
or
\
any
(
space_type
.
startswith
(
prefix
)
for
prefix
in
NDS_SPACES
+
[
'proxylessnas'
,
'mobilenetv3'
]):
any
(
space_type
.
startswith
(
prefix
)
for
prefix
in
NDS_SPACES
+
[
'proxylessnas'
,
'mobilenetv3'
,
'autoformer'
]):
pytest
.
skip
(
'The space has used unsupported APIs.'
)
pytest
.
skip
(
'The space has used unsupported APIs.'
)
if
strategy_type
in
[
'darts'
,
'gumbel'
]
and
space_type
==
'mobilenetv3'
:
if
strategy_type
in
[
'darts'
,
'gumbel'
]
and
space_type
==
'mobilenetv3'
:
pytest
.
skip
(
'Skip as it consumes too much memory.'
)
pytest
.
skip
(
'Skip as it consumes too much memory.'
)
...
...