OpenDAS / nni · Commits

Commit 8547b21c (unverified)
Authored Apr 21, 2022 by Yuge Zhang; committed by GitHub on Apr 21, 2022

Merge pull request #4760 from microsoft/dev-oneshot

[DO NOT SQUASH] One-shot as strategy

Parents: 58d205d3, 2355bacb
Changes: 23

Showing 3 changed files with 495 additions and 84 deletions:
nni/retiarii/strategy/oneshot.py   +22 −0
test/ut/retiarii/test_oneshot.py   +225 −84
test/ut/retiarii/test_oneshot_supermodules.py   +248 −0
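The merge turns the one-shot NAS implementations into ordinary Retiarii strategies: instead of wrapping a Lightning module in DartsModule or EnasModule by hand, the new test code drives them through RetiariiExperiment with the 'oneshot' execution engine. A minimal sketch of that usage, assembled from the test helpers in this commit; base_model, train_loader and valid_loader are stand-ins for the model space and MNIST loaders defined later in the diff, not part of the change itself:

    from nni.retiarii import strategy
    from nni.retiarii.evaluator.pytorch.lightning import Classification
    from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment

    # base_model: any @model_wrapper-decorated model space (e.g. SimpleNet below)
    # train_loader / valid_loader: ordinary DataLoaders (e.g. the ones built in _mnist_net below)
    evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, max_epochs=1)

    experiment = RetiariiExperiment(base_model, evaluator, strategy=strategy.DARTS())
    config = RetiariiExeConfig()
    config.execution_engine = 'oneshot'   # run the one-shot strategy in-process
    experiment.run(config)

    best = experiment.export_top_models()[0]   # a dict of chosen architecture decisions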
nni/retiarii/strategy/oneshot.py (new file, 0 → 100644)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .base import BaseStrategy

try:
    from nni.retiarii.oneshot.pytorch.strategy import (  # pylint: disable=unused-import
        DARTS,
        GumbelDARTS,
        Proxyless,
        ENAS,
        RandomOneShot
    )
except ImportError as import_err:
    _import_err = import_err

    class ImportFailedStrategy(BaseStrategy):
        def run(self, base_model, applied_mutators):
            raise _import_err

    # otherwise type checking will point to the wrong location
    globals()['DARTS'] = ImportFailedStrategy
    globals()['GumbelDARTS'] = ImportFailedStrategy
    globals()['Proxyless'] = ImportFailedStrategy
    globals()['ENAS'] = ImportFailedStrategy
    globals()['RandomOneShot'] = ImportFailedStrategy
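The module above defers the PyTorch-heavy import: if the one-shot backend cannot be imported, the strategy names are still defined but bound to ImportFailedStrategy, so the failure only surfaces when a strategy is actually run. A minimal sketch of that behaviour, assuming the import failed and that the placeholder needs no constructor arguments:

    from nni.retiarii.strategy import DARTS   # the name resolves thanks to the fallback binding above

    darts = DARTS()       # under the failed-import assumption, DARTS is ImportFailedStrategy
    darts.run(None, [])   # raises the original ImportError captured in _import_err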
test/ut/retiarii/test_oneshot.py
import argparse

import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import pytest
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data.sampler import RandomSampler
from torch.utils.data import Dataset, RandomSampler

from nni.retiarii.evaluator.pytorch.lightning import Classification, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
from nni.retiarii.oneshot.pytorch import (ConcatenateTrainValDataLoader, DartsModule, EnasModule, SNASModule,
                                          InterleavedTrainValDataLoader, ProxylessModule, RandomSampleModule)

import nni.retiarii.nn.pytorch as nn
from nni.retiarii import strategy, model_wrapper, basic_unit
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice


class DepthwiseSeparableConv(nn.Module):
    ...

@@ -26,119 +24,262 @@ class DepthwiseSeparableConv(nn.Module):
        return self.pointwise(self.depthwise(x))


class Net(pl.LightningModule):
    def __init__(self):
@model_wrapper
class SimpleNet(nn.Module):
    def __init__(self, value_choice=True):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = LayerChoice([
            nn.Conv2d(32, 64, 3, 1),
            DepthwiseSeparableConv(32, 64)
        ])
        self.dropout1 = nn.Dropout(.25)
        self.dropout2 = nn.Dropout(0.5)
        self.dropout_choice = InputChoice(2, 1)
        self.fc = LayerChoice([
            nn.Sequential(
                nn.Linear(9216, 64),
                nn.ReLU(),
                nn.Linear(64, 10),
            ),
            nn.Sequential(
                nn.Linear(9216, 128),
                nn.ReLU(),
                nn.Linear(128, 10),
            ),
            nn.Sequential(
                nn.Linear(9216, 256),
                nn.ReLU(),
                nn.Linear(256, 10),
            )
        ])
        self.dropout1 = LayerChoice([nn.Dropout(.25), nn.Dropout(.5), nn.Dropout(.75)])
        self.dropout2 = nn.Dropout(0.5)
        if value_choice:
            hidden = nn.ValueChoice([32, 64, 128])
        else:
            hidden = 64
        self.fc1 = nn.Linear(9216, hidden)
        self.fc2 = nn.Linear(hidden, 10)
        self.rpfc = nn.Linear(10, 10)
        self.input_ch = InputChoice(2, 1)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(self.conv2(x), 2)
        x1 = torch.flatten(self.dropout1(x), 1)
        x2 = torch.flatten(self.dropout2(x), 1)
        x = self.dropout_choice([x1, x2])
        x = self.fc(x)
        x = self.rpfc(x)
        x = torch.flatten(self.dropout1(x), 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        x1 = self.rpfc(x)
        x = self.input_ch([x, x1])
        output = F.log_softmax(x, dim=1)
        return output


@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def prepare_model_data():
    base_model = Net()
@model_wrapper
class MultiHeadAttentionNet(nn.Module):
    def __init__(self, head_count):
        super().__init__()
        embed_dim = ValueChoice(candidates=[32, 64])
        self.linear1 = nn.Linear(128, embed_dim)
        self.mhatt = nn.MultiheadAttention(embed_dim, head_count)
        self.linear2 = nn.Linear(embed_dim, 1)

    def forward(self, batch):
        query, key, value = batch
        q, k, v = self.linear1(query), self.linear1(key), self.linear1(value)
        output, _ = self.mhatt(q, k, v, need_weights=False)
        y = self.linear2(output)
        return F.relu(y)


@model_wrapper
class ValueChoiceConvNet(nn.Module):
    def __init__(self):
        super().__init__()
        ch1 = ValueChoice([16, 32])
        kernel = ValueChoice([3, 5])
        self.conv1 = nn.Conv2d(1, ch1, kernel, padding=kernel // 2)
        self.batch_norm = nn.BatchNorm2d(ch1)
        self.conv2 = nn.Conv2d(ch1, 64, 3)
        self.dropout1 = LayerChoice([nn.Dropout(.25), nn.Dropout(.5), nn.Dropout(.75)])
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.batch_norm(x)
        x = F.relu(x)
        x = F.max_pool2d(self.conv2(x), 2)
        x = torch.mean(x, (2, 3))
        x = self.fc(x)
        return F.log_softmax(x, dim=1)


@model_wrapper
class RepeatNet(nn.Module):
    def __init__(self):
        super().__init__()
        ch1 = ValueChoice([16, 32])
        kernel = ValueChoice([3, 5])
        self.conv1 = nn.Conv2d(1, ch1, kernel, padding=kernel // 2)
        self.batch_norm = nn.BatchNorm2d(ch1)
        self.conv2 = nn.Conv2d(ch1, 64, 3, padding=1)
        self.dropout1 = LayerChoice([nn.Dropout(.25), nn.Dropout(.5), nn.Dropout(.75)])
        self.fc = nn.Linear(64, 10)
        self.rpfc = nn.Repeat(nn.Linear(10, 10), (1, 4))

    def forward(self, x):
        x = self.conv1(x)
        x = self.batch_norm(x)
        x = F.relu(x)
        x = F.max_pool2d(self.conv2(x), 2)
        x = torch.mean(x, (2, 3))
        x = self.fc(x)
        x = self.rpfc(x)
        return F.log_softmax(x, dim=1)


@basic_unit
class MyOp(nn.Module):
    def __init__(self, some_ch):
        super().__init__()
        self.some_ch = some_ch
        self.batch_norm = nn.BatchNorm2d(some_ch)

    def forward(self, x):
        return self.batch_norm(x)


@model_wrapper
class CustomOpValueChoiceNet(nn.Module):
    def __init__(self):
        super().__init__()
        ch1 = ValueChoice([16, 32])
        kernel = ValueChoice([3, 5])
        self.conv1 = nn.Conv2d(1, ch1, kernel, padding=kernel // 2)
        self.batch_norm = MyOp(ch1)
        self.conv2 = nn.Conv2d(ch1, 64, 3, padding=1)
        self.dropout1 = LayerChoice([nn.Dropout(.25), nn.Dropout(.5), nn.Dropout(.75)])
        self.fc = nn.Linear(64, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.batch_norm(x)
        x = F.relu(x)
        x = F.max_pool2d(self.conv2(x), 2)
        x = torch.mean(x, (2, 3))
        x = self.fc(x)
        return F.log_softmax(x, dim=1)


def _mnist_net(type_):
    if type_ == 'simple':
        base_model = SimpleNet(False)
    elif type_ == 'simple_value_choice':
        base_model = SimpleNet()
    elif type_ == 'value_choice':
        base_model = ValueChoiceConvNet()
    elif type_ == 'repeat':
        base_model = RepeatNet()
    elif type_ == 'custom_op':
        base_model = CustomOpValueChoiceNet()
    else:
        raise ValueError(f'Unsupported type: {type_}')

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform)
    train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 10))
    train_loader = DataLoader(train_dataset, 64, sampler=train_random_sampler)
    valid_dataset = MNIST('data/mnist', train=False, download=True, transform=transform)
    valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 10))
    valid_loader = DataLoader(valid_dataset, 64, sampler=valid_random_sampler)
    train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform)
    train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 20))
    train_loader = DataLoader(train_dataset, 64, sampler=train_random_sampler)
    valid_dataset = MNIST('data/mnist', train=False, download=True, transform=transform)
    valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 20))
    valid_loader = DataLoader(valid_dataset, 64, sampler=valid_random_sampler)
    evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, max_epochs=1)

    return base_model, evaluator


def _multihead_attention_net():
    base_model = MultiHeadAttentionNet(1)

    class AttentionRandDataset(Dataset):
        def __init__(self, data_shape, gt_shape, len) -> None:
            super().__init__()
            self.datashape = data_shape
            self.gtshape = gt_shape
            self.len = len

        def __getitem__(self, index):
            q = torch.rand(self.datashape)
            k = torch.rand(self.datashape)
            v = torch.rand(self.datashape)
            gt = torch.rand(self.gtshape)
            return (q, k, v), gt

        def __len__(self):
            return self.len

    train_set = AttentionRandDataset((1, 128), (1, 1), 1000)
    val_set = AttentionRandDataset((1, 128), (1, 1), 500)
    train_loader = DataLoader(train_set, batch_size=32)
    val_loader = DataLoader(val_set, batch_size=32)

    evaluator = Regression(train_dataloader=train_loader, val_dataloaders=val_loader, max_epochs=1)

    return base_model, evaluator


def _test_strategy(strategy_, support_value_choice=True):
    to_test = [
        # (model, evaluator), support_or_not
        (_mnist_net('simple'), True),
        (_mnist_net('simple_value_choice'), support_value_choice),
        (_mnist_net('value_choice'), support_value_choice),
        (_mnist_net('repeat'), False),                    # no strategy supports repeat currently
        (_mnist_net('custom_op'), False),                 # this is definitely a NO
        (_multihead_attention_net(), support_value_choice),
    ]

    for (base_model, evaluator), support_or_not in to_test:
        print('Testing:', type(strategy_).__name__, type(base_model).__name__, type(evaluator).__name__, support_or_not)
        experiment = RetiariiExperiment(base_model, evaluator, strategy=strategy_)
    trainer_kwargs = {'max_epochs': 1}
        config = RetiariiExeConfig()
        config.execution_engine = 'oneshot'
    return base_model, train_loader, valid_loader, trainer_kwargs

        if support_or_not:
            experiment.run(config)
            assert isinstance(experiment.export_top_models()[0], dict)
        else:
            with pytest.raises(TypeError, match='not supported'):
                experiment.run(config)


@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_darts():
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    darts_model = DartsModule(cls.module)
    para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
    cls.trainer.fit(darts_model, para_loader)
    _test_strategy(strategy.DARTS())


@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_proxyless():
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    proxyless_model = ProxylessModule(cls.module)
    para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
    cls.trainer.fit(proxyless_model, para_loader)
    _test_strategy(strategy.Proxyless(), False)


@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_enas():
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    enas_model = EnasModule(cls.module)
    concat_loader = ConcatenateTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
    cls.trainer.fit(enas_model, concat_loader)
    _test_strategy(strategy.ENAS())


@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_random():
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    random_model = RandomSampleModule(cls.module)
    cls.trainer.fit(random_model, cls.train_dataloader, cls.val_dataloaders)
    _test_strategy(strategy.RandomOneShot())


@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_snas():
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    proxyless_model = SNASModule(cls.module, 1, use_temp_anneal=True)
    para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
    cls.trainer.fit(proxyless_model, para_loader)


@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_gumbel_darts():
    _test_strategy(strategy.GumbelDARTS())


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp', type=str, default='all', metavar='E',
                        help='exp to run, default = all')
                        help='experiment to run, default = all')
    args = parser.parse_args()

    if args.exp == 'all':
        ...

@@ -146,6 +287,6 @@ if __name__ == '__main__':
        test_proxyless()
        test_enas()
        test_random()
        test_snas()
        test_gumbel_darts()
    else:
        globals()[f'test_{args.exp}']()
test/ut/retiarii/test_oneshot_supermodules.py (new file, 0 → 100644)
import pytest

import numpy as np
import torch
import torch.nn as nn

from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, Linear, MultiheadAttention
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import (
    MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax
)
from nni.retiarii.oneshot.pytorch.supermodule.sampling import (
    MixedOpPathSamplingPolicy, PathSamplingLayer, PathSamplingInput
)
from nni.retiarii.oneshot.pytorch.supermodule.operation import MixedConv2d, NATIVE_MIXED_OPERATIONS
from nni.retiarii.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer, ProxylessMixedInput
from nni.retiarii.oneshot.pytorch.supermodule._operation_utils import Slicable as S, MaybeWeighted as W
from nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils import *


def test_slice():
    weight = np.ones((3, 7, 24, 23))
    assert S(weight)[:, 1:3, :, 9:13].shape == (3, 2, 24, 4)
    assert S(weight)[:, 1:W(3) * 2 + 1, :, 9:13].shape == (3, 6, 24, 4)
    assert S(weight)[:, 1:W(3) * 2 + 1].shape == (3, 6, 24, 23)

    # no effect
    assert S(weight)[:] is weight

    # list
    assert S(weight)[[slice(1), slice(2, 3)]].shape == (2, 7, 24, 23)
    assert S(weight)[[slice(1), slice(2, W(2) + 1)], W(2):].shape == (2, 5, 24, 23)

    # weighted
    weight = S(weight)[:W({1: 0.5, 2: 0.3, 3: 0.2})]
    weight = weight[:, 0, 0, 0]
    assert weight[0] == 1 and weight[1] == 0.5 and weight[2] == 0.2

    weight = np.ones((3, 6, 6))
    value = W({1: 0.5, 3: 0.5})
    weight = S(weight)[:, 3 - value:3 + value, 3 - value:3 + value]
    for i in range(0, 6):
        for j in range(0, 6):
            if 2 <= i <= 3 and 2 <= j <= 3:
                assert weight[0, i, j] == 1
            else:
                assert weight[1, i, j] == 0.5

    # weighted + list
    value = W({1: 0.5, 3: 0.5})
    weight = np.ones((8, 4))
    weight = S(weight)[[slice(value), slice(4, value + 4)]]
    assert weight.sum(1).tolist() == [4, 2, 2, 0, 4, 2, 2, 0]

    with pytest.raises(ValueError, match='one distinct'):
        # has to be exactly the same instance, equal is not enough
        weight = S(weight)[:W({1: 0.5}), :W({1: 0.5})]


def test_valuechoice_utils():
    chosen = {"exp": 3, "add": 1}
    vc0 = ValueChoice([3, 4, 6], label='exp') * 2 + ValueChoice([0, 1], label='add')

    assert evaluate_value_choice_with_dict(vc0, chosen) == 7
    vc = vc0 + ValueChoice([3, 4, 6], label='exp')
    assert evaluate_value_choice_with_dict(vc, chosen) == 10

    assert list(dedup_inner_choices([vc0, vc]).keys()) == ['exp', 'add']

    assert traverse_all_options(vc) == [9, 10, 12, 13, 18, 19]
    weights = dict(traverse_all_options(vc, weights={'exp': [0.5, 0.3, 0.2], 'add': [0.4, 0.6]}))
    ans = dict([(9, 0.2), (10, 0.3), (12, 0.12), (13, 0.18), (18, 0.08), (19, 0.12)])
    assert len(weights) == len(ans)
    for value, weight in ans.items():
        assert abs(weight - weights[value]) < 1e-6


def test_pathsampling_valuechoice():
    orig_conv = Conv2d(3, ValueChoice([3, 5, 7], label='123'), kernel_size=3)
    conv = MixedConv2d.mutate(orig_conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
    conv.resample(memo={'123': 5})
    assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 5
    conv.resample(memo={'123': 7})
    assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 7
    assert conv.export({})['123'] in [3, 5, 7]


def test_differentiable_valuechoice():
    orig_conv = Conv2d(3, ValueChoice([3, 5, 7], label='456'),
                       kernel_size=ValueChoice([3, 5, 7], label='123'),
                       padding=ValueChoice([3, 5, 7], label='123') // 2)
    conv = MixedConv2d.mutate(orig_conv, 'dummy', {}, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
    assert conv(torch.zeros((1, 3, 7, 7))).size(2) == 7

    assert set(conv.export({}).keys()) == {'123', '456'}


def _mixed_operation_sampling_sanity_check(operation, memo, *input):
    for native_op in NATIVE_MIXED_OPERATIONS:
        if native_op.bound_type == type(operation):
            mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
            break

    mutate_op.resample(memo=memo)
    return mutate_op(*input)


def _mixed_operation_differentiable_sanity_check(operation, *input):
    for native_op in NATIVE_MIXED_OPERATIONS:
        if native_op.bound_type == type(operation):
            mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
            break

    return mutate_op(*input)


def test_mixed_linear():
    linear = Linear(ValueChoice([3, 6, 9], label='shared'), ValueChoice([2, 4, 8]))
    _mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))
    _mixed_operation_sampling_sanity_check(linear, {'shared': 9}, torch.randn(2, 9))
    _mixed_operation_differentiable_sanity_check(linear, torch.randn(2, 9))

    linear = Linear(ValueChoice([3, 6, 9], label='shared'), ValueChoice([2, 4, 8]), bias=False)
    _mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))

    with pytest.raises(TypeError):
        linear = Linear(ValueChoice([3, 6, 9], label='shared'), ValueChoice([2, 4, 8]), bias=ValueChoice([False, True]))
        _mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))


def test_mixed_conv2d():
    conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([2, 4, 8], label='out') * 2, 1)
    assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'out': 4}, torch.randn(2, 3, 9, 9)).size(1) == 8
    _mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))

    # stride
    conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([2, 4, 8], label='out'), 1, stride=ValueChoice([1, 2], label='stride'))
    assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 2}, torch.randn(2, 3, 10, 10)).size(2) == 5
    assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 1}, torch.randn(2, 3, 10, 10)).size(2) == 10

    # groups, dw conv
    conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 6, 9], label='in'))
    assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10)).size() == torch.Size([2, 6, 10, 10])

    # make sure kernel is sliced correctly
    conv = Conv2d(1, 1, ValueChoice([1, 3], label='k'), bias=False)
    conv = MixedConv2d.mutate(conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
    with torch.no_grad():
        conv.weight.zero_()
        # only center is 1, must pick center to pass this test
        conv.weight[0, 0, 1, 1] = 1
    conv.resample({'k': 1})
    assert conv(torch.ones((1, 1, 3, 3))).sum().item() == 9


def test_mixed_batchnorm2d():
    bn = BatchNorm2d(ValueChoice([32, 64], label='dim'))

    assert _mixed_operation_sampling_sanity_check(bn, {'dim': 32}, torch.randn(2, 32, 3, 3)).size(1) == 32
    assert _mixed_operation_sampling_sanity_check(bn, {'dim': 64}, torch.randn(2, 64, 3, 3)).size(1) == 64

    _mixed_operation_differentiable_sanity_check(bn, torch.randn(2, 64, 3, 3))


def test_mixed_mhattn():
    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4)

    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4},
        torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8},
        torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 8))[0].size(-1) == 8

    _mixed_operation_differentiable_sanity_check(mhattn, torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 8))

    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), ValueChoice([2, 3, 4], label='heads'))
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'heads': 2},
        torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
    with pytest.raises(AssertionError, match='divisible'):
        assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'heads': 3},
            torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4

    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4, kdim=ValueChoice([5, 7], label='kdim'))
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'kdim': 7},
        torch.randn(7, 2, 4), torch.randn(7, 2, 7), torch.randn(7, 2, 4))[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'kdim': 5},
        torch.randn(7, 2, 8), torch.randn(7, 2, 5), torch.randn(7, 2, 8))[0].size(-1) == 8

    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4, vdim=ValueChoice([5, 8], label='vdim'))
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'vdim': 8},
        torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 8))[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'vdim': 5},
        torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 5))[0].size(-1) == 8

    _mixed_operation_differentiable_sanity_check(mhattn, torch.randn(5, 3, 8), torch.randn(5, 3, 8), torch.randn(5, 3, 8))


@pytest.mark.skipif(torch.__version__.startswith('1.7'), reason='batch_first is not supported for legacy PyTorch')
def test_mixed_mhattn_batch_first():
    # batch_first is not supported for legacy pytorch versions
    # mark 1.7 because 1.7 is used on legacy pipeline
    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 2, kdim=(ValueChoice([3, 7], label='kdim')), vdim=ValueChoice([5, 8], label='vdim'),
                                bias=False, add_bias_kv=True, batch_first=True)
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'kdim': 7, 'vdim': 8},
        torch.randn(2, 7, 4), torch.randn(2, 7, 7), torch.randn(2, 7, 8))[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'kdim': 3, 'vdim': 5},
        torch.randn(2, 7, 8), torch.randn(2, 7, 3), torch.randn(2, 7, 5))[0].size(-1) == 8

    _mixed_operation_differentiable_sanity_check(mhattn, torch.randn(1, 7, 8), torch.randn(1, 7, 7), torch.randn(1, 7, 8))


def test_pathsampling_layer_input():
    op = PathSamplingLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], label='ccc')
    with pytest.raises(RuntimeError, match='sample'):
        op(torch.randn(4, 2))

    op.resample({})
    assert op(torch.randn(4, 2)).size(-1) == 3
    assert op.search_space_spec()['ccc'].values == ['a', 'b']
    assert op.export({})['ccc'] in ['a', 'b']

    input = PathSamplingInput(5, 2, 'concat', 'ddd')
    sample = input.resample({})
    assert 'ddd' in sample
    assert len(sample['ddd']) == 2
    assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 4
    assert len(input.export({})['ddd']) == 2


def test_differentiable_layer_input():
    op = DifferentiableMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
    assert op(torch.randn(4, 2)).size(-1) == 3
    assert op.export({})['eee'] in ['a', 'b']
    assert len(list(op.parameters())) == 3

    input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
    assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2
    assert len(input.export({})['ddd']) == 2


def test_proxyless_layer_input():
    op = ProxylessMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
    assert op.resample({})['eee'] in ['a', 'b']
    assert op(torch.randn(4, 2)).size(-1) == 3
    assert op.export({})['eee'] in ['a', 'b']
    assert len(list(op.parameters())) == 3

    input = ProxylessMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
    assert input.resample({})['ddd'] in list(range(5))
    assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2])
    assert input.export({})['ddd'] in list(range(5))