Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
bcda469f
Unverified
Commit
bcda469f
authored
Jan 04, 2021
by
Yuge Zhang
Committed by
GitHub
Jan 04, 2021
Browse files
Update on NAS examples (#3240)
parent
3423117d
Changes
14
Show whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
15 additions
and
14 deletions
+15
-14
examples/nas/__init__.py
examples/nas/__init__.py
+0
-0
examples/nas/cdarts/utils.py
examples/nas/cdarts/utils.py
+1
-1
examples/nas/classic_nas-tf/train.py
examples/nas/classic_nas-tf/train.py
+1
-1
examples/nas/enas-tf/search.py
examples/nas/enas-tf/search.py
+1
-1
examples/nas/naive-tf/train.py
examples/nas/naive-tf/train.py
+2
-2
examples/nas/proxylessnas/main.py
examples/nas/proxylessnas/main.py
+2
-0
examples/nas/proxylessnas/putils.py
examples/nas/proxylessnas/putils.py
+1
-1
examples/nas/search_space_zoo/darts_example.py
examples/nas/search_space_zoo/darts_example.py
+1
-1
examples/nas/search_space_zoo/enas_macro_example.py
examples/nas/search_space_zoo/enas_macro_example.py
+1
-1
examples/nas/search_space_zoo/enas_micro_example.py
examples/nas/search_space_zoo/enas_micro_example.py
+1
-1
examples/nas/search_space_zoo/nasbench201.py
examples/nas/search_space_zoo/nasbench201.py
+2
-2
examples/nas/spos/utils.py
examples/nas/spos/utils.py
+1
-1
examples/nas/textnas/run_retrain.sh
examples/nas/textnas/run_retrain.sh
+0
-1
examples/nas/textnas/utils.py
examples/nas/textnas/utils.py
+1
-1
No files found.
examples/nas/__init__.py
deleted
100644 → 0
View file @
3423117d
examples/nas/cdarts/utils.py
View file @
bcda469f
...
...
@@ -14,7 +14,7 @@ import torch.nn as nn
from
genotypes
import
Genotype
from
ops
import
PRIMITIVES
from
nni.nas.pytorch.cdarts.utils
import
*
from
nni.
algorithms.
nas.pytorch.cdarts.utils
import
*
def
get_logger
(
file_path
):
...
...
examples/nas/classic_nas-tf/train.py
View file @
bcda469f
...
...
@@ -7,7 +7,7 @@ from tensorflow.keras.optimizers import SGD
import
nni
from
nni.nas.tensorflow.mutables
import
LayerChoice
,
InputChoice
from
nni.nas.tensorflow.classic_nas
import
get_and_apply_next_architecture
from
nni.
algorithms.
nas.tensorflow.classic_nas
import
get_and_apply_next_architecture
tf
.
get_logger
().
setLevel
(
'ERROR'
)
...
...
examples/nas/enas-tf/search.py
View file @
bcda469f
...
...
@@ -5,7 +5,7 @@
from
tensorflow.keras.losses
import
Reduction
,
SparseCategoricalCrossentropy
from
tensorflow.keras.optimizers
import
SGD
from
nni.nas.tensorflow
import
enas
from
nni.
algorithms.
nas.tensorflow
import
enas
import
datasets
from
macro
import
GeneralNetwork
...
...
examples/nas/naive-tf/train.py
View file @
bcda469f
...
...
@@ -8,7 +8,7 @@ from tensorflow.keras.losses import Reduction, SparseCategoricalCrossentropy
from
tensorflow.keras.optimizers
import
SGD
from
nni.nas.tensorflow.mutables
import
LayerChoice
,
InputChoice
from
nni.nas.tensorflow.enas
import
EnasTrainer
from
nni.
algorithms.
nas.tensorflow.enas
import
EnasTrainer
class
Net
(
Model
):
...
...
@@ -55,7 +55,7 @@ class Net(Model):
def
accuracy
(
truth
,
logits
):
truth
=
tf
.
reshape
(
truth
,
-
1
)
truth
=
tf
.
reshape
(
truth
,
(
-
1
,
)
)
predicted
=
tf
.
cast
(
tf
.
math
.
argmax
(
logits
,
axis
=
1
),
truth
.
dtype
)
equal
=
tf
.
cast
(
predicted
==
truth
,
tf
.
int32
)
return
tf
.
math
.
reduce_sum
(
equal
).
numpy
()
/
equal
.
shape
[
0
]
...
...
examples/nas/proxylessnas/main.py
View file @
bcda469f
import
json
import
logging
import
os
import
sys
...
...
@@ -102,6 +103,7 @@ if __name__ == "__main__":
log_frequency
=
10
)
trainer
.
fit
()
print
(
'Final architecture:'
,
trainer
.
export
())
json
.
dump
(
trainer
.
export
(),
open
(
'checkpoint.json'
,
'w'
))
elif
args
.
train_mode
==
'search_v1'
:
# this is architecture search
logger
.
info
(
'Creating ProxylessNasTrainer...'
)
...
...
examples/nas/proxylessnas/putils.py
View file @
bcda469f
...
...
@@ -85,7 +85,7 @@ def accuracy(output, target, topk=(1,)):
res
=
dict
()
for
k
in
topk
:
correct_k
=
correct
[:
k
].
view
(
-
1
).
float
().
sum
(
0
)
correct_k
=
correct
[:
k
].
reshape
(
-
1
).
float
().
sum
(
0
)
res
[
"acc{}"
.
format
(
k
)]
=
correct_k
.
mul_
(
1.0
/
batch_size
).
item
()
return
res
...
...
examples/nas/search_space_zoo/darts_example.py
View file @
bcda469f
...
...
@@ -10,7 +10,7 @@ import torch.nn as nn
import
datasets
from
nni.nas.pytorch.callbacks
import
ArchitectureCheckpoint
,
LRSchedulerCallback
from
nni.nas.pytorch.darts
import
DartsTrainer
from
nni.
algorithms.
nas.pytorch.darts
import
DartsTrainer
from
utils
import
accuracy
from
nni.nas.pytorch.search_space_zoo
import
DartsCell
...
...
examples/nas/search_space_zoo/enas_macro_example.py
View file @
bcda469f
...
...
@@ -8,7 +8,7 @@ from torchvision import transforms
from
torchvision.datasets
import
CIFAR10
from
nni.nas.pytorch
import
mutables
from
nni.nas.pytorch
import
enas
from
nni.
algorithms.
nas.pytorch
import
enas
from
utils
import
accuracy
,
reward_accuracy
from
nni.nas.pytorch.callbacks
import
(
ArchitectureCheckpoint
,
LRSchedulerCallback
)
...
...
examples/nas/search_space_zoo/enas_micro_example.py
View file @
bcda469f
...
...
@@ -7,7 +7,7 @@ from argparse import ArgumentParser
from
torchvision
import
transforms
from
torchvision.datasets
import
CIFAR10
from
nni.nas.pytorch
import
enas
from
nni.
algorithms.
nas.pytorch
import
enas
from
utils
import
accuracy
,
reward_accuracy
from
nni.nas.pytorch.callbacks
import
(
ArchitectureCheckpoint
,
LRSchedulerCallback
)
...
...
examples/nas/search_space_zoo/nasbench201.py
View file @
bcda469f
...
...
@@ -10,13 +10,13 @@ import torch
import
torch.nn
as
nn
import
torch.optim
as
optim
from
torch.utils.data
import
DataLoader
from
nni.nas.pytorch
import
enas
from
nni.algorithms.nas.pytorch.darts
import
DartsTrainer
from
nni.algorithms.nas.pytorch
import
enas
from
nni.nas.pytorch.utils
import
AverageMeterGroup
from
nni.nas.pytorch.nasbench201
import
NASBench201Cell
from
nni.nas.pytorch.fixed
import
apply_fixed_architecture
from
nni.nas.benchmarks.nasbench201
import
query_nb201_trial_stats
from
nni.nas.pytorch.callbacks
import
ArchitectureCheckpoint
,
LRSchedulerCallback
from
nni.nas.pytorch.darts
import
DartsTrainer
from
utils
import
accuracy
,
reward_accuracy
import
datasets
...
...
examples/nas/spos/utils.py
View file @
bcda469f
...
...
@@ -36,6 +36,6 @@ def accuracy(output, target, topk=(1, 5)):
res
=
dict
()
for
k
in
topk
:
correct_k
=
correct
[:
k
].
view
(
-
1
).
float
().
sum
(
0
)
correct_k
=
correct
[:
k
].
reshape
(
-
1
).
float
().
sum
(
0
)
res
[
"acc{}"
.
format
(
k
)]
=
correct_k
.
mul_
(
1.0
/
batch_size
).
item
()
return
res
examples/nas/textnas/run_retrain.sh
View file @
bcda469f
...
...
@@ -2,7 +2,6 @@
# Licensed under the MIT license.
export
PYTHONPATH
=
"
$(
pwd
)
"
export
CUDA_VISIBLE_DEVICES
=
0
python3
-u
retrain.py
\
--train_ratio
=
1.0
\
...
...
examples/nas/textnas/utils.py
View file @
bcda469f
...
...
@@ -14,7 +14,7 @@ logger = logging.getLogger("nni.textnas")
def
get_length
(
mask
):
length
=
torch
.
sum
(
mask
,
1
)
length
=
length
.
long
()
length
=
length
.
long
()
.
cpu
()
return
length
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment