Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
afe42cea
Unverified
Commit
afe42cea
authored
Oct 12, 2021
by
Yuge Zhang
Committed by
GitHub
Oct 12, 2021
Browse files
NAS benchmark integration (stage 2) - download (#4205)
parent
000de04b
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
170 additions
and
35 deletions
+170
-35
nni/nas/benchmarks/__init__.py
nni/nas/benchmarks/__init__.py
+1
-0
nni/nas/benchmarks/constants.py
nni/nas/benchmarks/constants.py
+18
-2
nni/nas/benchmarks/nasbench101/db_gen.py
nni/nas/benchmarks/nasbench101/db_gen.py
+4
-1
nni/nas/benchmarks/nasbench101/model.py
nni/nas/benchmarks/nasbench101/model.py
+6
-9
nni/nas/benchmarks/nasbench101/query.py
nni/nas/benchmarks/nasbench101/query.py
+7
-1
nni/nas/benchmarks/nasbench201/db_gen.py
nni/nas/benchmarks/nasbench201/db_gen.py
+4
-1
nni/nas/benchmarks/nasbench201/model.py
nni/nas/benchmarks/nasbench201/model.py
+6
-9
nni/nas/benchmarks/nasbench201/query.py
nni/nas/benchmarks/nasbench201/query.py
+7
-1
nni/nas/benchmarks/nds/db_gen.py
nni/nas/benchmarks/nds/db_gen.py
+4
-1
nni/nas/benchmarks/nds/model.py
nni/nas/benchmarks/nds/model.py
+6
-9
nni/nas/benchmarks/nds/query.py
nni/nas/benchmarks/nds/query.py
+7
-1
nni/nas/benchmarks/utils.py
nni/nas/benchmarks/utils.py
+100
-0
No files found.
nni/nas/benchmarks/__init__.py
View file @
afe42cea
from
.utils
import
load_benchmark
,
download_benchmark
nni/nas/benchmarks/constants.py
View file @
afe42cea
import
os
import
os
# TODO: need to be refactored to support automatic download
ENV_NNI_HOME
=
'NNI_HOME'
ENV_XDG_CACHE_HOME
=
'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR
=
'~/.cache'
DATABASE_DIR
=
os
.
environ
.
get
(
"NASBENCHMARK_DIR"
,
os
.
path
.
expanduser
(
"~/.nni/nasbenchmark"
))
def
_get_nasbenchmark_dir
():
nni_home
=
os
.
path
.
expanduser
(
os
.
getenv
(
ENV_NNI_HOME
,
os
.
path
.
join
(
os
.
getenv
(
ENV_XDG_CACHE_HOME
,
DEFAULT_CACHE_DIR
),
'nni'
)))
return
os
.
path
.
join
(
nni_home
,
'nasbenchmark'
)
DATABASE_DIR
=
_get_nasbenchmark_dir
()
DB_URLS
=
{
'nasbench101'
:
'https://nni.blob.core.windows.net/nasbenchmark/nasbench101-209f5694.db'
,
'nasbench201'
:
'https://nni.blob.core.windows.net/nasbenchmark/nasbench201-b2b60732.db'
,
'nds'
:
'https://nni.blob.core.windows.net/nasbenchmark/nds-5745c235.db'
}
nni/nas/benchmarks/nasbench101/db_gen.py
View file @
afe42cea
...
@@ -3,7 +3,8 @@ import argparse
...
@@ -3,7 +3,8 @@ import argparse
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
nasbench
import
api
# pylint: disable=import-error
from
nasbench
import
api
# pylint: disable=import-error
from
.model
import
db
,
Nb101TrialConfig
,
Nb101TrialStats
,
Nb101IntermediateStats
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
Nb101TrialConfig
,
Nb101TrialStats
,
Nb101IntermediateStats
from
.graph_util
import
nasbench_format_to_architecture_repr
,
hash_module
from
.graph_util
import
nasbench_format_to_architecture_repr
,
hash_module
...
@@ -13,6 +14,8 @@ def main():
...
@@ -13,6 +14,8 @@ def main():
help
=
'Path to the file to be converted, e.g., nasbench_full.tfrecord'
)
help
=
'Path to the file to be converted, e.g., nasbench_full.tfrecord'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
nasbench
=
api
.
NASBench
(
args
.
input_file
)
nasbench
=
api
.
NASBench
(
args
.
input_file
)
db
=
load_benchmark
(
'nasbench101'
)
with
db
:
with
db
:
db
.
create_tables
([
Nb101TrialConfig
,
Nb101TrialStats
,
Nb101IntermediateStats
])
db
.
create_tables
([
Nb101TrialConfig
,
Nb101TrialStats
,
Nb101IntermediateStats
])
for
hashval
in
tqdm
(
nasbench
.
hash_iterator
(),
desc
=
'Dumping data into database'
):
for
hashval
in
tqdm
(
nasbench
.
hash_iterator
(),
desc
=
'Dumping data into database'
):
...
...
nni/nas/benchmarks/nasbench101/model.py
View file @
afe42cea
import
os
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
from
playhouse.sqlite_ext
import
JSONField
,
SqliteExtDatabase
from
nni.nas.benchmarks.constants
import
DATABASE_DIR
from
nni.nas.benchmarks.utils
import
json_dumps
from
nni.nas.benchmarks.utils
import
json_dumps
db
=
SqliteExtDatabase
(
os
.
path
.
join
(
DATABASE_DIR
,
'nasbench101.db'
),
autoconnect
=
True
)
proxy
=
Proxy
(
)
class
Nb101TrialConfig
(
Model
):
class
Nb101TrialConfig
(
Model
):
...
@@ -35,7 +32,7 @@ class Nb101TrialConfig(Model):
...
@@ -35,7 +32,7 @@ class Nb101TrialConfig(Model):
num_epochs
=
IntegerField
(
index
=
True
)
num_epochs
=
IntegerField
(
index
=
True
)
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
class
Nb101TrialStats
(
Model
):
class
Nb101TrialStats
(
Model
):
...
@@ -68,7 +65,7 @@ class Nb101TrialStats(Model):
...
@@ -68,7 +65,7 @@ class Nb101TrialStats(Model):
training_time
=
FloatField
()
training_time
=
FloatField
()
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
class
Nb101IntermediateStats
(
Model
):
class
Nb101IntermediateStats
(
Model
):
...
@@ -99,4 +96,4 @@ class Nb101IntermediateStats(Model):
...
@@ -99,4 +96,4 @@ class Nb101IntermediateStats(Model):
training_time
=
FloatField
()
training_time
=
FloatField
()
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
nni/nas/benchmarks/nasbench101/query.py
View file @
afe42cea
...
@@ -2,7 +2,9 @@ import functools
...
@@ -2,7 +2,9 @@ import functools
from
peewee
import
fn
from
peewee
import
fn
from
playhouse.shortcuts
import
model_to_dict
from
playhouse.shortcuts
import
model_to_dict
from
.model
import
Nb101TrialStats
,
Nb101TrialConfig
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
Nb101TrialStats
,
Nb101TrialConfig
,
proxy
from
.graph_util
import
hash_module
,
infer_num_vertices
from
.graph_util
import
hash_module
,
infer_num_vertices
...
@@ -33,6 +35,10 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None,
...
@@ -33,6 +35,10 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None,
A generator of :class:`nni.nas.benchmark.nasbench101.Nb101TrialStats` objects,
A generator of :class:`nni.nas.benchmark.nasbench101.Nb101TrialStats` objects,
where each of them has been converted into a dict.
where each of them has been converted into a dict.
"""
"""
if
proxy
.
obj
is
None
:
proxy
.
initialize
(
load_benchmark
(
'nasbench101'
))
fields
=
[]
fields
=
[]
if
reduction
==
'none'
:
if
reduction
==
'none'
:
reduction
=
None
reduction
=
None
...
...
nni/nas/benchmarks/nasbench201/db_gen.py
View file @
afe42cea
...
@@ -4,8 +4,9 @@ import re
...
@@ -4,8 +4,9 @@ import re
import
tqdm
import
tqdm
import
torch
import
torch
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.constants
import
NONE
,
SKIP_CONNECT
,
CONV_1X1
,
CONV_3X3
,
AVG_POOL_3X3
from
.constants
import
NONE
,
SKIP_CONNECT
,
CONV_1X1
,
CONV_3X3
,
AVG_POOL_3X3
from
.model
import
db
,
Nb201TrialConfig
,
Nb201TrialStats
,
Nb201IntermediateStats
from
.model
import
Nb201TrialConfig
,
Nb201TrialStats
,
Nb201IntermediateStats
def
parse_arch_str
(
arch_str
):
def
parse_arch_str
(
arch_str
):
...
@@ -39,6 +40,8 @@ def main():
...
@@ -39,6 +40,8 @@ def main():
'imagenet16-120'
:
[
'train'
,
'x-valid'
,
'x-test'
,
'ori-test'
],
'imagenet16-120'
:
[
'train'
,
'x-valid'
,
'x-test'
,
'ori-test'
],
}
}
db
=
load_benchmark
(
'nasbench201'
)
with
db
:
with
db
:
db
.
create_tables
([
Nb201TrialConfig
,
Nb201TrialStats
,
Nb201IntermediateStats
])
db
.
create_tables
([
Nb201TrialConfig
,
Nb201TrialStats
,
Nb201IntermediateStats
])
print
(
'Loading NAS-Bench-201 pickle...'
)
print
(
'Loading NAS-Bench-201 pickle...'
)
...
...
nni/nas/benchmarks/nasbench201/model.py
View file @
afe42cea
import
os
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
from
playhouse.sqlite_ext
import
JSONField
,
SqliteExtDatabase
from
nni.nas.benchmarks.constants
import
DATABASE_DIR
from
nni.nas.benchmarks.utils
import
json_dumps
from
nni.nas.benchmarks.utils
import
json_dumps
db
=
SqliteExtDatabase
(
os
.
path
.
join
(
DATABASE_DIR
,
'nasbench201.db'
),
autoconnect
=
True
)
proxy
=
Proxy
(
)
class
Nb201TrialConfig
(
Model
):
class
Nb201TrialConfig
(
Model
):
...
@@ -48,7 +45,7 @@ class Nb201TrialConfig(Model):
...
@@ -48,7 +45,7 @@ class Nb201TrialConfig(Model):
])
])
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
class
Nb201TrialStats
(
Model
):
class
Nb201TrialStats
(
Model
):
...
@@ -113,7 +110,7 @@ class Nb201TrialStats(Model):
...
@@ -113,7 +110,7 @@ class Nb201TrialStats(Model):
ori_test_evaluation_time
=
FloatField
()
ori_test_evaluation_time
=
FloatField
()
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
class
Nb201IntermediateStats
(
Model
):
class
Nb201IntermediateStats
(
Model
):
...
@@ -157,4 +154,4 @@ class Nb201IntermediateStats(Model):
...
@@ -157,4 +154,4 @@ class Nb201IntermediateStats(Model):
ori_test_loss
=
FloatField
(
null
=
True
)
ori_test_loss
=
FloatField
(
null
=
True
)
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
nni/nas/benchmarks/nasbench201/query.py
View file @
afe42cea
...
@@ -2,7 +2,9 @@ import functools
...
@@ -2,7 +2,9 @@ import functools
from
peewee
import
fn
from
peewee
import
fn
from
playhouse.shortcuts
import
model_to_dict
from
playhouse.shortcuts
import
model_to_dict
from
.model
import
Nb201TrialStats
,
Nb201TrialConfig
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
Nb201TrialStats
,
Nb201TrialConfig
,
proxy
def
query_nb201_trial_stats
(
arch
,
num_epochs
,
dataset
,
reduction
=
None
,
include_intermediates
=
False
):
def
query_nb201_trial_stats
(
arch
,
num_epochs
,
dataset
,
reduction
=
None
,
include_intermediates
=
False
):
...
@@ -32,6 +34,10 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i
...
@@ -32,6 +34,10 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i
A generator of :class:`nni.nas.benchmark.nasbench201.Nb201TrialStats` objects,
A generator of :class:`nni.nas.benchmark.nasbench201.Nb201TrialStats` objects,
where each of them has been converted into a dict.
where each of them has been converted into a dict.
"""
"""
if
proxy
.
obj
is
None
:
proxy
.
initialize
(
load_benchmark
(
'nasbench201'
))
fields
=
[]
fields
=
[]
if
reduction
==
'none'
:
if
reduction
==
'none'
:
reduction
=
None
reduction
=
None
...
...
nni/nas/benchmarks/nds/db_gen.py
View file @
afe42cea
...
@@ -5,7 +5,8 @@ import os
...
@@ -5,7 +5,8 @@ import os
import
numpy
as
np
import
numpy
as
np
import
tqdm
import
tqdm
from
.model
import
db
,
NdsTrialConfig
,
NdsTrialStats
,
NdsIntermediateStats
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
NdsTrialConfig
,
NdsTrialStats
,
NdsIntermediateStats
def
inject_item
(
db
,
item
,
proposer
,
dataset
,
generator
):
def
inject_item
(
db
,
item
,
proposer
,
dataset
,
generator
):
...
@@ -120,6 +121,8 @@ def main():
...
@@ -120,6 +121,8 @@ def main():
'Vanilla_rng3.json'
'Vanilla_rng3.json'
]
]
db
=
load_benchmark
(
'nds'
)
with
db
:
with
db
:
db
.
create_tables
([
NdsTrialConfig
,
NdsTrialStats
,
NdsIntermediateStats
])
db
.
create_tables
([
NdsTrialConfig
,
NdsTrialStats
,
NdsIntermediateStats
])
for
json_idx
,
json_file
in
enumerate
(
sweep_list
,
start
=
1
):
for
json_idx
,
json_file
in
enumerate
(
sweep_list
,
start
=
1
):
...
...
nni/nas/benchmarks/nds/model.py
View file @
afe42cea
import
os
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
from
playhouse.sqlite_ext
import
JSONField
,
SqliteExtDatabase
from
nni.nas.benchmarks.constants
import
DATABASE_DIR
from
nni.nas.benchmarks.utils
import
json_dumps
from
nni.nas.benchmarks.utils
import
json_dumps
db
=
SqliteExtDatabase
(
os
.
path
.
join
(
DATABASE_DIR
,
'nds.db'
),
autoconnect
=
True
)
proxy
=
Proxy
(
)
class
NdsTrialConfig
(
Model
):
class
NdsTrialConfig
(
Model
):
...
@@ -67,7 +64,7 @@ class NdsTrialConfig(Model):
...
@@ -67,7 +64,7 @@ class NdsTrialConfig(Model):
num_epochs
=
IntegerField
()
num_epochs
=
IntegerField
()
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
class
NdsTrialStats
(
Model
):
class
NdsTrialStats
(
Model
):
...
@@ -112,7 +109,7 @@ class NdsTrialStats(Model):
...
@@ -112,7 +109,7 @@ class NdsTrialStats(Model):
iter_time
=
FloatField
()
iter_time
=
FloatField
()
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
class
NdsIntermediateStats
(
Model
):
class
NdsIntermediateStats
(
Model
):
...
@@ -140,4 +137,4 @@ class NdsIntermediateStats(Model):
...
@@ -140,4 +137,4 @@ class NdsIntermediateStats(Model):
test_acc
=
FloatField
()
test_acc
=
FloatField
()
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
nni/nas/benchmarks/nds/query.py
View file @
afe42cea
...
@@ -2,7 +2,9 @@ import functools
...
@@ -2,7 +2,9 @@ import functools
from
peewee
import
fn
from
peewee
import
fn
from
playhouse.shortcuts
import
model_to_dict
from
playhouse.shortcuts
import
model_to_dict
from
.model
import
NdsTrialStats
,
NdsTrialConfig
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
NdsTrialStats
,
NdsTrialConfig
,
proxy
def
query_nds_trial_stats
(
model_family
,
proposer
,
generator
,
model_spec
,
cell_spec
,
dataset
,
def
query_nds_trial_stats
(
model_family
,
proposer
,
generator
,
model_spec
,
cell_spec
,
dataset
,
...
@@ -41,6 +43,10 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
...
@@ -41,6 +43,10 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
A generator of :class:`nni.nas.benchmark.nds.NdsTrialStats` objects,
A generator of :class:`nni.nas.benchmark.nds.NdsTrialStats` objects,
where each of them has been converted into a dict.
where each of them has been converted into a dict.
"""
"""
if
proxy
.
obj
is
None
:
proxy
.
initialize
(
load_benchmark
(
'nds'
))
fields
=
[]
fields
=
[]
if
reduction
==
'none'
:
if
reduction
==
'none'
:
reduction
=
None
reduction
=
None
...
...
nni/nas/benchmarks/utils.py
View file @
afe42cea
import
functools
import
functools
import
hashlib
import
json
import
json
import
logging
import
os
import
shutil
import
tempfile
from
pathlib
import
Path
import
requests
import
tqdm
from
playhouse.sqlite_ext
import
SqliteExtDatabase
from
.constants
import
DB_URLS
,
DATABASE_DIR
json_dumps
=
functools
.
partial
(
json
.
dumps
,
sort_keys
=
True
)
json_dumps
=
functools
.
partial
(
json
.
dumps
,
sort_keys
=
True
)
# to prevent repetitive loading of benchmarks
_loaded_benchmarks
=
{}
def
load_or_download_file
(
local_path
:
str
,
download_url
:
str
,
download
:
bool
=
False
,
progress
:
bool
=
True
):
f
=
None
hash_prefix
=
Path
(
local_path
).
stem
.
split
(
'-'
)[
-
1
]
_logger
=
logging
.
getLogger
(
__name__
)
try
:
sha256
=
hashlib
.
sha256
()
if
Path
(
local_path
).
exists
():
_logger
.
info
(
'"%s" already exists. Checking hash.'
,
local_path
)
with
Path
(
local_path
).
open
(
'rb'
)
as
fr
:
while
True
:
chunk
=
fr
.
read
(
8192
)
if
len
(
chunk
)
==
0
:
break
sha256
.
update
(
chunk
)
elif
download
:
_logger
.
info
(
'"%s" does not exist. Downloading "%s"'
,
local_path
,
download_url
)
# Follow download implementation in torchvision:
# We deliberately save it in a temp file and move it after
# download is complete. This prevents a local working checkpoint
# being overridden by a broken download.
dst_dir
=
Path
(
local_path
).
parent
dst_dir
.
mkdir
(
exist_ok
=
True
,
parents
=
True
)
f
=
tempfile
.
NamedTemporaryFile
(
delete
=
False
,
dir
=
dst_dir
)
r
=
requests
.
get
(
download_url
,
stream
=
True
)
total_length
=
int
(
r
.
headers
.
get
(
'content-length'
))
with
tqdm
.
tqdm
(
total
=
total_length
,
disable
=
not
progress
,
unit
=
'B'
,
unit_scale
=
True
,
unit_divisor
=
1024
)
as
pbar
:
for
chunk
in
r
.
iter_content
(
8192
):
f
.
write
(
chunk
)
sha256
.
update
(
chunk
)
pbar
.
update
(
len
(
chunk
))
f
.
flush
()
else
:
raise
FileNotFoundError
(
'Download is not enabled, but file still does not exist: {}'
.
format
(
local_path
))
digest
=
sha256
.
hexdigest
()
if
not
digest
.
startswith
(
hash_prefix
):
raise
RuntimeError
(
'Invalid hash value (expected "{}", got "{}")'
.
format
(
hash_prefix
,
digest
))
if
f
is
not
None
:
shutil
.
move
(
f
.
name
,
local_path
)
finally
:
if
f
is
not
None
:
f
.
close
()
if
os
.
path
.
exists
(
f
.
name
):
os
.
remove
(
f
.
name
)
def
load_benchmark
(
benchmark
:
str
)
->
SqliteExtDatabase
:
"""
Load a benchmark as a database.
Parmaeters
----------
benchmark : str
Benchmark name like nasbench201.
"""
if
benchmark
in
_loaded_benchmarks
:
return
_loaded_benchmarks
[
benchmark
]
url
=
DB_URLS
[
benchmark
]
local_path
=
os
.
path
.
join
(
DATABASE_DIR
,
os
.
path
.
basename
(
url
))
load_or_download_file
(
local_path
,
url
)
_loaded_benchmarks
[
benchmark
]
=
SqliteExtDatabase
(
local_path
,
autoconnect
=
True
)
return
_loaded_benchmarks
[
benchmark
]
def
download_benchmark
(
benchmark
:
str
,
progress
:
bool
=
True
):
"""
Download a converted benchmark.
Parameters
----------
benchmark : str
Benchmark name like nasbench201.
"""
url
=
DB_URLS
[
benchmark
]
local_path
=
os
.
path
.
join
(
DATABASE_DIR
,
os
.
path
.
basename
(
url
))
load_or_download_file
(
local_path
,
url
,
True
,
progress
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment