Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
7b653a92
Unverified
Commit
7b653a92
authored
Nov 26, 2019
by
Chi Song
Committed by
GitHub
Nov 26, 2019
Browse files
[NAS] simplify log, and fix a bug on pdarts exporting (#1777)
parent
1398540e
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
18 additions
and
46 deletions
+18
-46
docs/en_US/NAS/Overview.md
docs/en_US/NAS/Overview.md
+2
-2
examples/nas/darts/retrain.py
examples/nas/darts/retrain.py
+3
-11
examples/nas/darts/search.py
examples/nas/darts/search.py
+1
-10
examples/nas/enas/search.py
examples/nas/enas/search.py
+3
-10
examples/nas/pdarts/search.py
examples/nas/pdarts/search.py
+1
-8
src/sdk/pynni/nni/nas/pytorch/mutables.py
src/sdk/pynni/nni/nas/pytorch/mutables.py
+0
-1
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py
+8
-3
src/sdk/pynni/nni/nas/pytorch/trainer.py
src/sdk/pynni/nni/nas/pytorch/trainer.py
+0
-1
No files found.
docs/en_US/NAS/Overview.md
View file @
7b653a92
...
@@ -51,8 +51,8 @@ cd examples/nas/pdarts
...
@@ -51,8 +51,8 @@ cd examples/nas/pdarts
python3 search.py
python3 search.py
# train the best architecture, it's the same progress as darts.
# train the best architecture, it's the same progress as darts.
cd
examples/nas
/darts
cd
..
/darts
python3 retrain.py
--arc-checkpoint
./checkpoints/epoch_2.json
python3 retrain.py
--arc-checkpoint
.
./pdarts
/checkpoints/epoch_2.json
```
```
## Use NNI API
## Use NNI API
...
...
examples/nas/darts/retrain.py
View file @
7b653a92
...
@@ -4,24 +4,16 @@ from argparse import ArgumentParser
...
@@ -4,24 +4,16 @@ from argparse import ArgumentParser
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
from
nni.nas.pytorch.fixed
import
apply_fixed_architecture
from
nni.nas.pytorch.utils
import
AverageMeter
from
torch.utils.tensorboard
import
SummaryWriter
from
torch.utils.tensorboard
import
SummaryWriter
import
datasets
import
datasets
import
utils
import
utils
from
model
import
CNN
from
model
import
CNN
from
nni.nas.pytorch.fixed
import
apply_fixed_architecture
from
nni.nas.pytorch.utils
import
AverageMeter
logger
=
logging
.
getLogger
()
logger
=
logging
.
getLogger
(
'nni'
)
fmt
=
'[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging
.
Formatter
.
converter
=
time
.
localtime
formatter
=
logging
.
Formatter
(
fmt
,
'%m/%d/%Y, %I:%M:%S %p'
)
std_out_info
=
logging
.
StreamHandler
()
std_out_info
.
setFormatter
(
formatter
)
logger
.
setLevel
(
logging
.
INFO
)
logger
.
addHandler
(
std_out_info
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
device
=
torch
.
device
(
"cuda"
if
torch
.
cuda
.
is_available
()
else
"cpu"
)
writer
=
SummaryWriter
()
writer
=
SummaryWriter
()
...
...
examples/nas/darts/search.py
View file @
7b653a92
...
@@ -11,16 +11,7 @@ from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallbac
...
@@ -11,16 +11,7 @@ from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallbac
from
nni.nas.pytorch.darts
import
DartsTrainer
from
nni.nas.pytorch.darts
import
DartsTrainer
from
utils
import
accuracy
from
utils
import
accuracy
logger
=
logging
.
getLogger
()
logger
=
logging
.
getLogger
(
'nni'
)
fmt
=
'[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging
.
Formatter
.
converter
=
time
.
localtime
formatter
=
logging
.
Formatter
(
fmt
,
'%m/%d/%Y, %I:%M:%S %p'
)
std_out_info
=
logging
.
StreamHandler
()
std_out_info
.
setFormatter
(
formatter
)
logger
.
setLevel
(
logging
.
INFO
)
logger
.
addHandler
(
std_out_info
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
ArgumentParser
(
"darts"
)
parser
=
ArgumentParser
(
"darts"
)
...
...
examples/nas/enas/search.py
View file @
7b653a92
...
@@ -9,19 +9,12 @@ import datasets
...
@@ -9,19 +9,12 @@ import datasets
from
macro
import
GeneralNetwork
from
macro
import
GeneralNetwork
from
micro
import
MicroNetwork
from
micro
import
MicroNetwork
from
nni.nas.pytorch
import
enas
from
nni.nas.pytorch
import
enas
from
nni.nas.pytorch.callbacks
import
LRSchedulerCallback
,
ArchitectureCheckpoint
from
nni.nas.pytorch.callbacks
import
(
ArchitectureCheckpoint
,
LRSchedulerCallback
)
from
utils
import
accuracy
,
reward_accuracy
from
utils
import
accuracy
,
reward_accuracy
logger
=
logging
.
getLogger
()
logger
=
logging
.
getLogger
(
'nni'
)
fmt
=
'[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logging
.
Formatter
.
converter
=
time
.
localtime
formatter
=
logging
.
Formatter
(
fmt
,
'%m/%d/%Y, %I:%M:%S %p'
)
std_out_info
=
logging
.
StreamHandler
()
std_out_info
.
setFormatter
(
formatter
)
logger
.
setLevel
(
logging
.
INFO
)
logger
.
addHandler
(
std_out_info
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
ArgumentParser
(
"enas"
)
parser
=
ArgumentParser
(
"enas"
)
...
...
examples/nas/pdarts/search.py
View file @
7b653a92
...
@@ -19,16 +19,9 @@ if True:
...
@@ -19,16 +19,9 @@ if True:
from
model
import
CNN
from
model
import
CNN
import
datasets
import
datasets
logger
=
logging
.
getLogger
()
fmt
=
'[%(asctime)s] %(levelname)s (%(name)s/%(threadName)s) %(message)s'
logger
=
logging
.
getLogger
(
'nni'
)
logging
.
Formatter
.
converter
=
time
.
localtime
formatter
=
logging
.
Formatter
(
fmt
,
'%m/%d/%Y, %I:%M:%S %p'
)
std_out_info
=
logging
.
StreamHandler
()
std_out_info
.
setFormatter
(
formatter
)
logger
.
setLevel
(
logging
.
INFO
)
logger
.
addHandler
(
std_out_info
)
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
parser
=
ArgumentParser
(
"pdarts"
)
parser
=
ArgumentParser
(
"pdarts"
)
...
...
src/sdk/pynni/nni/nas/pytorch/mutables.py
View file @
7b653a92
...
@@ -5,7 +5,6 @@ import torch.nn as nn
...
@@ -5,7 +5,6 @@ import torch.nn as nn
from
nni.nas.pytorch.utils
import
global_mutable_counting
from
nni.nas.pytorch.utils
import
global_mutable_counting
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
logger
.
setLevel
(
logging
.
INFO
)
class
Mutable
(
nn
.
Module
):
class
Mutable
(
nn
.
Module
):
...
...
src/sdk/pynni/nni/nas/pytorch/pdarts/trainer.py
View file @
7b653a92
# Copyright (c) Microsoft Corporation.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Licensed under the MIT license.
import
json
import
logging
import
logging
from
nni.nas.pytorch.callbacks
import
LRSchedulerCallback
from
nni.nas.pytorch.callbacks
import
LRSchedulerCallback
from
nni.nas.pytorch.darts
import
DartsTrainer
from
nni.nas.pytorch.darts
import
DartsTrainer
from
nni.nas.pytorch.trainer
import
BaseTrainer
from
nni.nas.pytorch.trainer
import
BaseTrainer
,
TorchTensorEncoder
from
.mutator
import
PdartsMutator
from
.mutator
import
PdartsMutator
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
logger
.
setLevel
(
logging
.
INFO
)
class
PdartsTrainer
(
BaseTrainer
):
class
PdartsTrainer
(
BaseTrainer
):
...
@@ -55,7 +55,7 @@ class PdartsTrainer(BaseTrainer):
...
@@ -55,7 +55,7 @@ class PdartsTrainer(BaseTrainer):
self
.
trainer
=
DartsTrainer
(
model
,
mutator
=
self
.
mutator
,
loss
=
criterion
,
optimizer
=
optim
,
self
.
trainer
=
DartsTrainer
(
model
,
mutator
=
self
.
mutator
,
loss
=
criterion
,
optimizer
=
optim
,
callbacks
=
darts_callbacks
,
**
self
.
darts_parameters
)
callbacks
=
darts_callbacks
,
**
self
.
darts_parameters
)
logger
.
info
(
"start pdarts training %s..."
,
epoch
)
logger
.
info
(
"start pdarts training
epoch
%s..."
,
epoch
)
self
.
trainer
.
train
()
self
.
trainer
.
train
()
...
@@ -67,5 +67,10 @@ class PdartsTrainer(BaseTrainer):
...
@@ -67,5 +67,10 @@ class PdartsTrainer(BaseTrainer):
def
validate
(
self
):
def
validate
(
self
):
self
.
model
.
validate
()
self
.
model
.
validate
()
def
export
(
self
,
file
):
mutator_export
=
self
.
mutator
.
export
()
with
open
(
file
,
"w"
)
as
f
:
json
.
dump
(
mutator_export
,
f
,
indent
=
2
,
sort_keys
=
True
,
cls
=
TorchTensorEncoder
)
def
checkpoint
(
self
):
def
checkpoint
(
self
):
raise
NotImplementedError
(
"Not implemented yet"
)
raise
NotImplementedError
(
"Not implemented yet"
)
src/sdk/pynni/nni/nas/pytorch/trainer.py
View file @
7b653a92
...
@@ -7,7 +7,6 @@ import torch
...
@@ -7,7 +7,6 @@ import torch
from
.base_trainer
import
BaseTrainer
from
.base_trainer
import
BaseTrainer
_logger
=
logging
.
getLogger
(
__name__
)
_logger
=
logging
.
getLogger
(
__name__
)
_logger
.
setLevel
(
logging
.
INFO
)
class
TorchTensorEncoder
(
json
.
JSONEncoder
):
class
TorchTensorEncoder
(
json
.
JSONEncoder
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment