Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
7b653a92
Unverified
Commit
7b653a92
authored
Nov 26, 2019
by
Chi Song
Committed by
GitHub
Nov 26, 2019
Browse files
[NAS] simplify log, and fix a bug on pdarts exporting (#1777)
parent
1398540e
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
18 additions
and
46 deletions
+18
-46
docs/en_US/NAS/Overview.md
docs/en_US/NAS/Overview.md
+2
-2
examples/nas/darts/retrain.py
examples/nas/darts/retrain.py
+3
-11
examples/nas/darts/search.py
examples/nas/darts/search.py
+1
-10
examples/nas/enas/search.py
examples/nas/enas/search.py
+3
-10
examples/nas/pdarts/search.py
examples/nas/pdarts/search.py
+1
-8
src/sdk/pynni/nni/nas/pytorch/mutables.py
src/sdk/pynni/nni/nas/pytorch/mutables.py
+0
-1
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py
+8
-3
src/sdk/pynni/nni/nas/pytorch/trainer.py
src/sdk/pynni/nni/nas/pytorch/trainer.py
+0
-1
No files found.
docs/en_US/NAS/Overview.md
View file @
7b653a92
...
...
@@ -51,8 +51,8 @@ cd examples/nas/pdarts
python3 search.py
# train the best architecture, it's the same progress as darts.
cd
examples/nas
/darts
python3 retrain.py
--arc-checkpoint
./checkpoints/epoch_2.json
cd
..
/darts
python3 retrain.py
--arc-checkpoint
.
./pdarts
/checkpoints/epoch_2.json
```
## Use NNI API
...
...
examples/nas/darts/retrain.py
View file @
7b653a92
...
...
@@ -4,24 +4,16 @@ from argparse import ArgumentParser
import
torch
import
torch.nn
as
nn
from
nni.nas.pytorch.fixed
import
apply_fixed_architecture
from
nni.nas.pytorch.utils
import
AverageMeter
from
torch.utils.tensorboard
import
SummaryWriter
import
datasets
import
utils
from
model
import
CNN
from
nni.nas.pytorch.fixed
import
apply_fixed_architecture
from
nni.nas.pytorch.utils
import
AverageMeter
logger
=
logging
.
getLogger
()
fmt
=
'[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging
.
Formatter
.
converter
=
time
.
localtime
formatter
=
logging
.
Formatter
(
fmt
,
'%m/%d/%Y, %I:%M:%S %p'
)
logger
=
logging
.
getLogger
(
'nni'
)
std_out_info
=
logging
.
StreamHandler
()
std_out_info
.
setFormatter
(
formatter
)
logger
.
setLevel
(
logging
.
INFO
)
logger
.
addHandler
(
std_out_info
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
writer
=
SummaryWriter
()
...
...
examples/nas/darts/search.py
View file @
7b653a92
...
...
@@ -11,16 +11,7 @@ from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallbac
from
nni.nas.pytorch.darts
import
DartsTrainer
from
utils
import
accuracy
logger
=
logging
.
getLogger
()
fmt
=
'[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging
.
Formatter
.
converter
=
time
.
localtime
formatter
=
logging
.
Formatter
(
fmt
,
'%m/%d/%Y, %I:%M:%S %p'
)
std_out_info
=
logging
.
StreamHandler
()
std_out_info
.
setFormatter
(
formatter
)
logger
.
setLevel
(
logging
.
INFO
)
logger
.
addHandler
(
std_out_info
)
logger
=
logging
.
getLogger
(
'nni'
)
if
__name__
==
"__main__"
:
parser
=
ArgumentParser
(
"darts"
)
...
...
examples/nas/enas/search.py
View file @
7b653a92
...
...
@@ -9,19 +9,12 @@ import datasets
from
macro
import
GeneralNetwork
from
micro
import
MicroNetwork
from
nni.nas.pytorch
import
enas
from
nni.nas.pytorch.callbacks
import
LRSchedulerCallback
,
ArchitectureCheckpoint
from
nni.nas.pytorch.callbacks
import
(
ArchitectureCheckpoint
,
LRSchedulerCallback
)
from
utils
import
accuracy
,
reward_accuracy
logger
=
logging
.
getLogger
()
logger
=
logging
.
getLogger
(
'nni'
)
fmt
=
'[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging
.
Formatter
.
converter
=
time
.
localtime
formatter
=
logging
.
Formatter
(
fmt
,
'%m/%d/%Y, %I:%M:%S %p'
)
std_out_info
=
logging
.
StreamHandler
()
std_out_info
.
setFormatter
(
formatter
)
logger
.
setLevel
(
logging
.
INFO
)
logger
.
addHandler
(
std_out_info
)
if
__name__
==
"__main__"
:
parser
=
ArgumentParser
(
"enas"
)
...
...
examples/nas/pdarts/search.py
View file @
7b653a92
...
...
@@ -19,16 +19,9 @@ if True:
from
model
import
CNN
import
datasets
logger
=
logging
.
getLogger
()
fmt
=
'[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging
.
Formatter
.
converter
=
time
.
localtime
formatter
=
logging
.
Formatter
(
fmt
,
'%m/%d/%Y, %I:%M:%S %p'
)
logger
=
logging
.
getLogger
(
'nni'
)
std_out_info
=
logging
.
StreamHandler
()
std_out_info
.
setFormatter
(
formatter
)
logger
.
setLevel
(
logging
.
INFO
)
logger
.
addHandler
(
std_out_info
)
if
__name__
==
"__main__"
:
parser
=
ArgumentParser
(
"pdarts"
)
...
...
src/sdk/pynni/nni/nas/pytorch/mutables.py
View file @
7b653a92
...
...
@@ -5,7 +5,6 @@ import torch.nn as nn
from
nni.nas.pytorch.utils
import
global_mutable_counting
logger
=
logging
.
getLogger
(
__name__
)
logger
.
setLevel
(
logging
.
INFO
)
class
Mutable
(
nn
.
Module
):
...
...
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py
View file @
7b653a92
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import
json
import
logging
from
nni.nas.pytorch.callbacks
import
LRSchedulerCallback
from
nni.nas.pytorch.darts
import
DartsTrainer
from
nni.nas.pytorch.trainer
import
BaseTrainer
from
nni.nas.pytorch.trainer
import
BaseTrainer
,
TorchTensorEncoder
from
.mutator
import
PdartsMutator
logger
=
logging
.
getLogger
(
__name__
)
logger
.
setLevel
(
logging
.
INFO
)
class
PdartsTrainer
(
BaseTrainer
):
...
...
@@ -55,7 +55,7 @@ class PdartsTrainer(BaseTrainer):
self
.
trainer
=
DartsTrainer
(
model
,
mutator
=
self
.
mutator
,
loss
=
criterion
,
optimizer
=
optim
,
callbacks
=
darts_callbacks
,
**
self
.
darts_parameters
)
logger
.
info
(
"start pdarts training %s..."
,
epoch
)
logger
.
info
(
"start pdarts training
epoch
%s..."
,
epoch
)
self
.
trainer
.
train
()
...
...
@@ -67,5 +67,10 @@ class PdartsTrainer(BaseTrainer):
def
validate
(
self
):
self
.
model
.
validate
()
def
export
(
self
,
file
):
mutator_export
=
self
.
mutator
.
export
()
with
open
(
file
,
"w"
)
as
f
:
json
.
dump
(
mutator_export
,
f
,
indent
=
2
,
sort_keys
=
True
,
cls
=
TorchTensorEncoder
)
def
checkpoint
(
self
):
raise
NotImplementedError
(
"Not implemented yet"
)
src/sdk/pynni/nni/nas/pytorch/trainer.py
View file @
7b653a92
...
...
@@ -7,7 +7,6 @@ import torch
from
.base_trainer
import
BaseTrainer
_logger
=
logging
.
getLogger
(
__name__
)
_logger
.
setLevel
(
logging
.
INFO
)
class
TorchTensorEncoder
(
json
.
JSONEncoder
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment