Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
75e5d5b5
"git@developer.sourcefind.cn:chenpangpang/diffusers.git" did not exist on "8d6487f3cbe89bb6e32f82fc9f04df6ce001ef24"
Unverified
Commit
75e5d5b5
authored
Aug 05, 2022
by
J-shang
Committed by
GitHub
Aug 05, 2022
Browse files
rm useless file (#5044)
parent
33cdb5b6
Changes
14
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
13 additions
and
327 deletions
+13
-327
nni/nas/oneshot/pytorch/base_lightning.py
nni/nas/oneshot/pytorch/base_lightning.py
+4
-4
test/algo/nas/test_lightning_trainer.py
test/algo/nas/test_lightning_trainer.py
+5
-5
test/training_service/config/examples/classic-nas-pytorch-v2.yml
...aining_service/config/examples/classic-nas-pytorch-v2.yml
+0
-18
test/training_service/config/examples/classic-nas-pytorch.yml
.../training_service/config/examples/classic-nas-pytorch.yml
+0
-21
test/training_service/config/examples/classic-nas-tf2.yml
test/training_service/config/examples/classic-nas-tf2.yml
+0
-21
test/training_service/config/integration_tests.yml
test/training_service/config/integration_tests.yml
+2
-16
test/training_service/config/integration_tests_config_v2.yml
test/training_service/config/integration_tests_config_v2.yml
+2
-16
test/training_service/config/integration_tests_tf2.yml
test/training_service/config/integration_tests_tf2.yml
+0
-14
test/training_service/config/tuners/regularized_evolution_tuner-v2.yml
..._service/config/tuners/regularized_evolution_tuner-v2.yml
+0
-14
test/training_service/config/tuners/regularized_evolution_tuner.yml
...ing_service/config/tuners/regularized_evolution_tuner.yml
+0
-20
test/ut/sdk/models/pytorch_models/__init__.py
test/ut/sdk/models/pytorch_models/__init__.py
+0
-4
test/ut/sdk/models/pytorch_models/mutable_scope.py
test/ut/sdk/models/pytorch_models/mutable_scope.py
+0
-95
test/ut/sdk/models/pytorch_models/naive.py
test/ut/sdk/models/pytorch_models/naive.py
+0
-45
test/ut/sdk/models/pytorch_models/nested.py
test/ut/sdk/models/pytorch_models/nested.py
+0
-34
No files found.
nni/nas/oneshot/pytorch/base_lightning.py
View file @
75e5d5b5
...
@@ -374,11 +374,11 @@ class BaseOneShotLightningModule(pl.LightningModule):
...
@@ -374,11 +374,11 @@ class BaseOneShotLightningModule(pl.LightningModule):
def
on_fit_end
(
self
):
def
on_fit_end
(
self
):
return
self
.
model
.
on_fit_end
()
return
self
.
model
.
on_fit_end
()
def
on_train_batch_start
(
self
,
batch
,
batch_idx
,
unused
=
0
):
def
on_train_batch_start
(
self
,
batch
,
batch_idx
,
*
args
,
**
kwargs
):
return
self
.
model
.
on_train_batch_start
(
batch
,
batch_idx
,
unused
)
return
self
.
model
.
on_train_batch_start
(
batch
,
batch_idx
,
*
args
,
**
kwargs
)
def
on_train_batch_end
(
self
,
outputs
,
batch
,
batch_idx
,
unused
=
0
):
def
on_train_batch_end
(
self
,
outputs
,
batch
,
batch_idx
,
*
args
,
**
kwargs
):
return
self
.
model
.
on_train_batch_end
(
outputs
,
batch
,
batch_idx
,
unused
)
return
self
.
model
.
on_train_batch_end
(
outputs
,
batch
,
batch_idx
,
*
args
,
**
kwargs
)
# Deprecated hooks in pytorch-lightning
# Deprecated hooks in pytorch-lightning
def
on_epoch_start
(
self
):
def
on_epoch_start
(
self
):
...
...
test/algo/nas/test_lightning_trainer.py
View file @
75e5d5b5
...
@@ -16,9 +16,9 @@ from torchvision.datasets import MNIST
...
@@ -16,9 +16,9 @@ from torchvision.datasets import MNIST
debug
=
False
debug
=
False
progress_bar
_refresh_rate
=
0
enable_
progress_bar
=
False
if
debug
:
if
debug
:
progress_bar
_refresh_rate
=
1
enable_
progress_bar
=
True
class
MNISTModel
(
nn
.
Module
):
class
MNISTModel
(
nn
.
Module
):
...
@@ -96,7 +96,7 @@ def test_mnist():
...
@@ -96,7 +96,7 @@ def test_mnist():
lightning
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
lightning
=
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
2
,
limit_train_batches
=
0.25
,
# for faster training
max_epochs
=
2
,
limit_train_batches
=
0.25
,
# for faster training
progress_bar
_refresh_rate
=
progress_bar_refresh_rate
)
enable_
progress_bar
=
enable_progress_bar
)
lightning
.
_execute
(
MNISTModel
)
lightning
.
_execute
(
MNISTModel
)
assert
_get_final_result
()
>
0.7
assert
_get_final_result
()
>
0.7
_reset
()
_reset
()
...
@@ -113,7 +113,7 @@ def test_diabetes():
...
@@ -113,7 +113,7 @@ def test_diabetes():
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
20
),
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
20
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
20
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
20
),
max_epochs
=
100
,
max_epochs
=
100
,
progress_bar
_refresh_rate
=
progress_bar_refresh_rate
)
enable_
progress_bar
=
enable_progress_bar
)
lightning
.
_execute
(
FCNet
(
train_dataset
.
x
.
shape
[
1
],
1
))
lightning
.
_execute
(
FCNet
(
train_dataset
.
x
.
shape
[
1
],
1
))
assert
_get_final_result
()
<
2e4
assert
_get_final_result
()
<
2e4
_reset
()
_reset
()
...
@@ -134,7 +134,7 @@ def test_fit_api():
...
@@ -134,7 +134,7 @@ def test_fit_api():
def
lightning
():
return
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
def
lightning
():
return
pl
.
Classification
(
train_dataloader
=
pl
.
DataLoader
(
train_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
val_dataloaders
=
pl
.
DataLoader
(
test_dataset
,
batch_size
=
100
),
max_epochs
=
1
,
limit_train_batches
=
0.1
,
# for faster training
max_epochs
=
1
,
limit_train_batches
=
0.1
,
# for faster training
progress_bar
_refresh_rate
=
progress_bar_refresh_rate
)
enable_
progress_bar
=
enable_progress_bar
)
# Lightning will have some cache in models / trainers,
# Lightning will have some cache in models / trainers,
# which is problematic if we call fit multiple times.
# which is problematic if we call fit multiple times.
lightning
().
fit
(
lambda
:
MNISTModel
())
lightning
().
fit
(
lambda
:
MNISTModel
())
...
...
test/training_service/config/examples/classic-nas-pytorch-v2.yml
deleted
100644 → 0
View file @
33cdb5b6
experimentName
:
default_test
searchSpaceFile
:
ni-nas-search-space.json
trialCommand
:
python3 main.py --epochs 1 --batches
1
trialCodeDirectory
:
../../../../examples/nas/legacy/classic_nas
trialGpuNumber
:
0
trialConcurrency
:
1
maxExperimentDuration
:
15m
maxTrialNumber
:
1
tuner
:
name
:
PPOTuner
classArgs
:
optimize_mode
:
maximize
trainingService
:
platform
:
local
assessor
:
name
:
Medianstop
classArgs
:
optimize_mode
:
maximize
test/training_service/config/examples/classic-nas-pytorch.yml
deleted
100644 → 0
View file @
33cdb5b6
authorName
:
nni
experimentName
:
default_test
maxExecDuration
:
10m
maxTrialNum
:
1
trialConcurrency
:
1
searchSpacePath
:
nni-nas-search-space.json
tuner
:
builtinTunerName
:
PPOTuner
classArgs
:
optimize_mode
:
maximize
trial
:
command
:
python3 mnist.py --epochs
1
codeDir
:
../../../../examples/nas/legacy/classic_nas
gpuNum
:
0
useAnnotation
:
false
multiPhase
:
false
multiThread
:
false
trainingServicePlatform
:
local
\ No newline at end of file
test/training_service/config/examples/classic-nas-tf2.yml
deleted
100644 → 0
View file @
33cdb5b6
authorName
:
nni
experimentName
:
default_test
maxExecDuration
:
10m
maxTrialNum
:
1
trialConcurrency
:
1
searchSpacePath
:
nni-nas-search-space-tf2.json
tuner
:
builtinTunerName
:
PPOTuner
classArgs
:
optimize_mode
:
maximize
trial
:
command
:
python3 train.py --epochs
1
codeDir
:
../../../../examples/nas/legacy/classic_nas-tf
gpuNum
:
0
useAnnotation
:
false
multiPhase
:
false
multiThread
:
false
trainingServicePlatform
:
local
\ No newline at end of file
test/training_service/config/integration_tests.yml
View file @
75e5d5b5
...
@@ -114,20 +114,6 @@ testCases:
...
@@ -114,20 +114,6 @@ testCases:
#- name: nested-ss
#- name: nested-ss
# configFile: test/training_service/config/examples/mnist-nested-search-space.yml
# configFile: test/training_service/config/examples/mnist-nested-search-space.yml
-
name
:
classic-nas-gen-ss
configFile
:
test/training_service/config/examples/classic-nas-pytorch.yml
launchCommand
:
nnictl ss_gen --trial_command="python3 mnist.py --epochs 1" --trial_dir=../examples/nas/legacy/classic_nas --file=training_service/config/examples/nni-nas-search-space.json
stopCommand
:
experimentStatusCheck
:
False
trainingService
:
local
-
name
:
classic-nas-pytorch
configFile
:
test/training_service/config/examples/classic-nas-pytorch.yml
# remove search space file
stopCommand
:
nnictl stop
onExitCommand
:
python3 -c "import os; os.remove('training_service/config/examples/nni-nas-search-space.json')"
trainingService
:
local
#########################################################################
#########################################################################
# nni features test
# nni features test
#########################################################################
#########################################################################
...
@@ -247,8 +233,8 @@ testCases:
...
@@ -247,8 +233,8 @@ testCases:
#- name: tuner-metis
#- name: tuner-metis
# configFile: test/training_service/config/tuners/metis.yml
# configFile: test/training_service/config/tuners/metis.yml
-
name
:
tuner-regularized_evolution
#
- name: tuner-regularized_evolution
configFile
:
test/training_service/config/tuners/regularized_evolution_tuner.yml
#
configFile: test/training_service/config/tuners/regularized_evolution_tuner.yml
#########################################################################
#########################################################################
# nni customized-tuners test
# nni customized-tuners test
...
...
test/training_service/config/integration_tests_config_v2.yml
View file @
75e5d5b5
...
@@ -38,20 +38,6 @@ testCases:
...
@@ -38,20 +38,6 @@ testCases:
configFile
:
test/training_service/config/examples/cifar10-pytorch-adl.yml
configFile
:
test/training_service/config/examples/cifar10-pytorch-adl.yml
trainingService
:
adl
trainingService
:
adl
-
name
:
classic-nas-gen-ss
configFile
:
test/training_service/config/examples/classic-nas-pytorch-v2.yml
launchCommand
:
nnictl ss_gen --trial_command="python3 mnist.py --epochs 1" --trial_dir=../examples/nas/legacy/classic_nas --file=training_service/config/examples/nni-nas-search-space.json
stopCommand
:
experimentStatusCheck
:
False
trainingService
:
local
-
name
:
classic-nas-pytorch
configFile
:
test/training_service/config/examples/classic-nas-pytorch-v2.yml
# remove search space file
stopCommand
:
nnictl stop
onExitCommand
:
python3 -c "import os; os.remove('training_service/examples/nni-nas-search-space.json')"
trainingService
:
local
#########################################################################
#########################################################################
# nni features test
# nni features test
#########################################################################
#########################################################################
...
@@ -124,8 +110,8 @@ testCases:
...
@@ -124,8 +110,8 @@ testCases:
#########################################################################
#########################################################################
# nni tuners test
# nni tuners test
#########################################################################
#########################################################################
-
name
:
tuner-regularized_evolution
#
- name: tuner-regularized_evolution
configFile
:
test/training_service/config/tuners/regularized_evolution_tuner-v2.yml
#
configFile: test/training_service/config/tuners/regularized_evolution_tuner-v2.yml
#########################################################################
#########################################################################
# nni customized-tuners test
# nni customized-tuners test
...
...
test/training_service/config/integration_tests_tf2.yml
View file @
75e5d5b5
...
@@ -56,20 +56,6 @@ testCases:
...
@@ -56,20 +56,6 @@ testCases:
configFile
:
test/training_service/config/examples/cifar10-pytorch-adl.yml
configFile
:
test/training_service/config/examples/cifar10-pytorch-adl.yml
trainingService
:
adl
trainingService
:
adl
-
name
:
classic-nas-gen-ss
configFile
:
test/training_service/config/examples/classic-nas-tf2.yml
launchCommand
:
nnictl ss_gen --trial_command="python3 train.py --epochs 1" --trial_dir=../examples/nas/legacy/classic_nas-tf --file=training_service/config/examples/nni-nas-search-space-tf2.json
stopCommand
:
experimentStatusCheck
:
False
trainingService
:
local
-
name
:
classic-nas-tensorflow2
configFile
:
test/training_service/config/examples/classic-nas-tf2.yml
# remove search space file
stopCommand
:
nnictl stop
onExitCommand
:
python3 -c 'import os; os.remove("training_service/config/examples/nni-nas-search-space-tf2.json")'
trainingService
:
local
#########################################################################
#########################################################################
# nni features test
# nni features test
#########################################################################
#########################################################################
...
...
test/training_service/config/tuners/regularized_evolution_tuner-v2.yml
deleted
100644 → 0
View file @
33cdb5b6
experimentName
:
default_test
searchSpaceFile
:
seach_space_classic_nas.json
trialCommand
:
python3 mnist.py --epochs
1
trialCodeDirectory
:
../../../../examples/nas/legacy/classic_nas
trialGpuNumber
:
0
trialConcurrency
:
1
maxExperimentDuration
:
15m
maxTrialNumber
:
1
tuner
:
name
:
RegularizedEvolutionTuner
classArgs
:
optimize_mode
:
maximize
trainingService
:
platform
:
local
test/training_service/config/tuners/regularized_evolution_tuner.yml
deleted
100644 → 0
View file @
33cdb5b6
authorName
:
nni
experimentName
:
default_test
maxExecDuration
:
10m
maxTrialNum
:
1
trialConcurrency
:
1
searchSpacePath
:
seach_space_classic_nas.json
tuner
:
builtinTunerName
:
RegularizedEvolutionTuner
classArgs
:
optimize_mode
:
maximize
trial
:
codeDir
:
../../../../examples/nas/legacy/classic_nas
command
:
python3 mnist.py --epochs
1
gpuNum
:
0
useAnnotation
:
false
multiPhase
:
false
multiThread
:
false
trainingServicePlatform
:
local
test/ut/sdk/models/pytorch_models/__init__.py
View file @
75e5d5b5
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
from
.mutable_scope
import
SpaceWithMutableScope
from
.naive
import
NaiveSearchSpace
from
.nested
import
NestedSpace
test/ut/sdk/models/pytorch_models/mutable_scope.py
deleted
100644 → 0
View file @
33cdb5b6
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
nni.nas.pytorch.mutables
import
LayerChoice
,
InputChoice
,
MutableScope
class
Cell
(
MutableScope
):
def
__init__
(
self
,
cell_name
,
prev_labels
,
channels
):
super
().
__init__
(
cell_name
)
self
.
input_choice
=
InputChoice
(
choose_from
=
prev_labels
,
n_chosen
=
1
,
return_mask
=
True
,
key
=
cell_name
+
"_input"
)
self
.
op_choice
=
LayerChoice
([
nn
.
Conv2d
(
channels
,
channels
,
3
,
padding
=
1
),
nn
.
Conv2d
(
channels
,
channels
,
5
,
padding
=
2
),
nn
.
MaxPool2d
(
3
,
stride
=
1
,
padding
=
1
),
nn
.
AvgPool2d
(
3
,
stride
=
1
,
padding
=
1
),
nn
.
Identity
()
],
key
=
cell_name
+
"_op"
)
def
forward
(
self
,
prev_layers
):
chosen_input
,
chosen_mask
=
self
.
input_choice
(
prev_layers
)
cell_out
=
self
.
op_choice
(
chosen_input
)
return
cell_out
,
chosen_mask
class
Node
(
MutableScope
):
def
__init__
(
self
,
node_name
,
prev_node_names
,
channels
):
super
().
__init__
(
node_name
)
self
.
cell_x
=
Cell
(
node_name
+
"_x"
,
prev_node_names
,
channels
)
self
.
cell_y
=
Cell
(
node_name
+
"_y"
,
prev_node_names
,
channels
)
def
forward
(
self
,
prev_layers
):
out_x
,
mask_x
=
self
.
cell_x
(
prev_layers
)
out_y
,
mask_y
=
self
.
cell_y
(
prev_layers
)
return
out_x
+
out_y
,
mask_x
|
mask_y
class
Layer
(
nn
.
Module
):
def
__init__
(
self
,
num_nodes
,
channels
):
super
().
__init__
()
self
.
num_nodes
=
num_nodes
self
.
nodes
=
nn
.
ModuleList
()
node_labels
=
[
InputChoice
.
NO_KEY
,
InputChoice
.
NO_KEY
]
for
i
in
range
(
num_nodes
):
node_labels
.
append
(
"node_{}"
.
format
(
i
))
self
.
nodes
.
append
(
Node
(
node_labels
[
-
1
],
node_labels
[:
-
1
],
channels
))
self
.
final_conv_w
=
nn
.
Parameter
(
torch
.
zeros
(
channels
,
self
.
num_nodes
+
2
,
channels
,
1
,
1
),
requires_grad
=
True
)
self
.
bn
=
nn
.
BatchNorm2d
(
channels
,
affine
=
False
)
def
forward
(
self
,
pprev
,
prev
):
prev_nodes_out
=
[
pprev
,
prev
]
nodes_used_mask
=
torch
.
zeros
(
self
.
num_nodes
+
2
,
dtype
=
torch
.
bool
,
device
=
prev
.
device
)
for
i
in
range
(
self
.
num_nodes
):
node_out
,
mask
=
self
.
nodes
[
i
](
prev_nodes_out
)
nodes_used_mask
[:
mask
.
size
(
0
)]
|=
mask
.
to
(
prev
.
device
)
# NOTE: which device should we put mask on?
prev_nodes_out
.
append
(
node_out
)
unused_nodes
=
torch
.
cat
([
out
for
used
,
out
in
zip
(
nodes_used_mask
,
prev_nodes_out
)
if
not
used
],
1
)
unused_nodes
=
F
.
relu
(
unused_nodes
)
conv_weight
=
self
.
final_conv_w
[:,
~
nodes_used_mask
,
:,
:,
:]
conv_weight
=
conv_weight
.
view
(
conv_weight
.
size
(
0
),
-
1
,
1
,
1
)
out
=
F
.
conv2d
(
unused_nodes
,
conv_weight
)
return
prev
,
self
.
bn
(
out
)
class
SpaceWithMutableScope
(
nn
.
Module
):
def
__init__
(
self
,
test_case
,
num_layers
=
4
,
num_nodes
=
5
,
channels
=
16
,
in_channels
=
3
,
num_classes
=
10
):
super
().
__init__
()
self
.
test_case
=
test_case
self
.
num_layers
=
num_layers
self
.
stem
=
nn
.
Sequential
(
nn
.
Conv2d
(
in_channels
,
channels
,
3
,
1
,
1
,
bias
=
False
),
nn
.
BatchNorm2d
(
channels
)
)
self
.
layers
=
nn
.
ModuleList
()
for
_
in
range
(
self
.
num_layers
+
2
):
self
.
layers
.
append
(
Layer
(
num_nodes
,
channels
))
self
.
gap
=
nn
.
AdaptiveAvgPool2d
(
1
)
self
.
dense
=
nn
.
Linear
(
channels
,
num_classes
)
def
forward
(
self
,
x
):
prev
=
cur
=
self
.
stem
(
x
)
for
layer
in
self
.
layers
:
prev
,
cur
=
layer
(
prev
,
cur
)
cur
=
self
.
gap
(
F
.
relu
(
cur
)).
view
(
x
.
size
(
0
),
-
1
)
return
self
.
dense
(
cur
)
test/ut/sdk/models/pytorch_models/naive.py
deleted
100644 → 0
View file @
33cdb5b6
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
nni.nas.pytorch.mutables
import
LayerChoice
,
InputChoice
class
NaiveSearchSpace
(
nn
.
Module
):
def
__init__
(
self
,
test_case
):
super
().
__init__
()
self
.
test_case
=
test_case
self
.
conv1
=
LayerChoice
([
nn
.
Conv2d
(
3
,
6
,
3
,
padding
=
1
),
nn
.
Conv2d
(
3
,
6
,
5
,
padding
=
2
)])
self
.
pool
=
nn
.
MaxPool2d
(
2
,
2
)
self
.
conv2
=
LayerChoice
([
nn
.
Conv2d
(
6
,
16
,
3
,
padding
=
1
),
nn
.
Conv2d
(
6
,
16
,
5
,
padding
=
2
)],
return_mask
=
True
)
self
.
conv3
=
nn
.
Conv2d
(
16
,
16
,
1
)
self
.
skipconnect
=
InputChoice
(
n_candidates
=
1
)
self
.
skipconnect2
=
InputChoice
(
n_candidates
=
2
,
return_mask
=
True
)
self
.
bn
=
nn
.
BatchNorm2d
(
16
)
self
.
gap
=
nn
.
AdaptiveAvgPool2d
(
1
)
self
.
fc
=
nn
.
Linear
(
16
,
10
)
def
forward
(
self
,
x
):
bs
=
x
.
size
(
0
)
x
=
self
.
pool
(
F
.
relu
(
self
.
conv1
(
x
)))
x0
,
mask
=
self
.
conv2
(
x
)
self
.
test_case
.
assertEqual
(
mask
.
size
(),
torch
.
Size
([
2
]))
x1
=
F
.
relu
(
self
.
conv3
(
x0
))
_
,
mask
=
self
.
skipconnect2
([
x0
,
x1
])
x0
=
self
.
skipconnect
([
x0
])
if
x0
is
not
None
:
x1
+=
x0
x
=
self
.
pool
(
self
.
bn
(
x1
))
self
.
test_case
.
assertEqual
(
mask
.
size
(),
torch
.
Size
([
2
]))
x
=
self
.
gap
(
x
).
view
(
bs
,
-
1
)
x
=
self
.
fc
(
x
)
return
x
test/ut/sdk/models/pytorch_models/nested.py
deleted
100644 → 0
View file @
33cdb5b6
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
nni.nas.pytorch.mutables
import
LayerChoice
,
InputChoice
class
MutableOp
(
nn
.
Module
):
def
__init__
(
self
,
kernel_size
):
super
().
__init__
()
self
.
conv
=
nn
.
Conv2d
(
3
,
120
,
kernel_size
,
padding
=
kernel_size
//
2
)
self
.
nested_mutable
=
InputChoice
(
n_candidates
=
10
)
def
forward
(
self
,
x
):
return
self
.
conv
(
x
)
class
NestedSpace
(
nn
.
Module
):
# this doesn't pass tests
def
__init__
(
self
,
test_case
):
super
().
__init__
()
self
.
test_case
=
test_case
self
.
conv1
=
LayerChoice
([
MutableOp
(
3
),
MutableOp
(
5
)])
self
.
gap
=
nn
.
AdaptiveAvgPool2d
(
1
)
self
.
fc1
=
nn
.
Linear
(
120
,
10
)
def
forward
(
self
,
x
):
bs
=
x
.
size
(
0
)
x
=
F
.
relu
(
self
.
conv1
(
x
))
x
=
self
.
gap
(
x
).
view
(
bs
,
-
1
)
x
=
self
.
fc
(
x
)
return
x
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment