Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
afe42cea
Unverified
Commit
afe42cea
authored
Oct 12, 2021
by
Yuge Zhang
Committed by
GitHub
Oct 12, 2021
Browse files
NAS benchmark integration (stage 2) - download (#4205)
parent
000de04b
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
170 additions
and
35 deletions
+170
-35
nni/nas/benchmarks/__init__.py
nni/nas/benchmarks/__init__.py
+1
-0
nni/nas/benchmarks/constants.py
nni/nas/benchmarks/constants.py
+18
-2
nni/nas/benchmarks/nasbench101/db_gen.py
nni/nas/benchmarks/nasbench101/db_gen.py
+4
-1
nni/nas/benchmarks/nasbench101/model.py
nni/nas/benchmarks/nasbench101/model.py
+6
-9
nni/nas/benchmarks/nasbench101/query.py
nni/nas/benchmarks/nasbench101/query.py
+7
-1
nni/nas/benchmarks/nasbench201/db_gen.py
nni/nas/benchmarks/nasbench201/db_gen.py
+4
-1
nni/nas/benchmarks/nasbench201/model.py
nni/nas/benchmarks/nasbench201/model.py
+6
-9
nni/nas/benchmarks/nasbench201/query.py
nni/nas/benchmarks/nasbench201/query.py
+7
-1
nni/nas/benchmarks/nds/db_gen.py
nni/nas/benchmarks/nds/db_gen.py
+4
-1
nni/nas/benchmarks/nds/model.py
nni/nas/benchmarks/nds/model.py
+6
-9
nni/nas/benchmarks/nds/query.py
nni/nas/benchmarks/nds/query.py
+7
-1
nni/nas/benchmarks/utils.py
nni/nas/benchmarks/utils.py
+100
-0
No files found.
nni/nas/benchmarks/__init__.py
View file @
afe42cea
from
.utils
import
load_benchmark
,
download_benchmark
nni/nas/benchmarks/constants.py
View file @
afe42cea
import
os
# TODO: need to be refactored to support automatic download
ENV_NNI_HOME
=
'NNI_HOME'
ENV_XDG_CACHE_HOME
=
'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR
=
'~/.cache'
DATABASE_DIR
=
os
.
environ
.
get
(
"NASBENCHMARK_DIR"
,
os
.
path
.
expanduser
(
"~/.nni/nasbenchmark"
))
def
_get_nasbenchmark_dir
():
nni_home
=
os
.
path
.
expanduser
(
os
.
getenv
(
ENV_NNI_HOME
,
os
.
path
.
join
(
os
.
getenv
(
ENV_XDG_CACHE_HOME
,
DEFAULT_CACHE_DIR
),
'nni'
)))
return
os
.
path
.
join
(
nni_home
,
'nasbenchmark'
)
DATABASE_DIR
=
_get_nasbenchmark_dir
()
DB_URLS
=
{
'nasbench101'
:
'https://nni.blob.core.windows.net/nasbenchmark/nasbench101-209f5694.db'
,
'nasbench201'
:
'https://nni.blob.core.windows.net/nasbenchmark/nasbench201-b2b60732.db'
,
'nds'
:
'https://nni.blob.core.windows.net/nasbenchmark/nds-5745c235.db'
}
nni/nas/benchmarks/nasbench101/db_gen.py
View file @
afe42cea
...
...
@@ -3,7 +3,8 @@ import argparse
from
tqdm
import
tqdm
from
nasbench
import
api
# pylint: disable=import-error
from
.model
import
db
,
Nb101TrialConfig
,
Nb101TrialStats
,
Nb101IntermediateStats
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
Nb101TrialConfig
,
Nb101TrialStats
,
Nb101IntermediateStats
from
.graph_util
import
nasbench_format_to_architecture_repr
,
hash_module
...
...
@@ -13,6 +14,8 @@ def main():
help
=
'Path to the file to be converted, e.g., nasbench_full.tfrecord'
)
args
=
parser
.
parse_args
()
nasbench
=
api
.
NASBench
(
args
.
input_file
)
db
=
load_benchmark
(
'nasbench101'
)
with
db
:
db
.
create_tables
([
Nb101TrialConfig
,
Nb101TrialStats
,
Nb101IntermediateStats
])
for
hashval
in
tqdm
(
nasbench
.
hash_iterator
(),
desc
=
'Dumping data into database'
):
...
...
nni/nas/benchmarks/nasbench101/model.py
View file @
afe42cea
import
os
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
from
playhouse.sqlite_ext
import
JSONField
,
SqliteExtDatabase
from
nni.nas.benchmarks.constants
import
DATABASE_DIR
from
nni.nas.benchmarks.utils
import
json_dumps
db
=
SqliteExtDatabase
(
os
.
path
.
join
(
DATABASE_DIR
,
'nasbench101.db'
),
autoconnect
=
True
)
proxy
=
Proxy
(
)
class
Nb101TrialConfig
(
Model
):
...
...
@@ -35,7 +32,7 @@ class Nb101TrialConfig(Model):
num_epochs
=
IntegerField
(
index
=
True
)
class
Meta
:
database
=
db
database
=
proxy
class
Nb101TrialStats
(
Model
):
...
...
@@ -68,7 +65,7 @@ class Nb101TrialStats(Model):
training_time
=
FloatField
()
class
Meta
:
database
=
db
database
=
proxy
class
Nb101IntermediateStats
(
Model
):
...
...
@@ -99,4 +96,4 @@ class Nb101IntermediateStats(Model):
training_time
=
FloatField
()
class
Meta
:
database
=
db
database
=
proxy
nni/nas/benchmarks/nasbench101/query.py
View file @
afe42cea
...
...
@@ -2,7 +2,9 @@ import functools
from
peewee
import
fn
from
playhouse.shortcuts
import
model_to_dict
from
.model
import
Nb101TrialStats
,
Nb101TrialConfig
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
Nb101TrialStats
,
Nb101TrialConfig
,
proxy
from
.graph_util
import
hash_module
,
infer_num_vertices
...
...
@@ -33,6 +35,10 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None,
A generator of :class:`nni.nas.benchmark.nasbench101.Nb101TrialStats` objects,
where each of them has been converted into a dict.
"""
if
proxy
.
obj
is
None
:
proxy
.
initialize
(
load_benchmark
(
'nasbench101'
))
fields
=
[]
if
reduction
==
'none'
:
reduction
=
None
...
...
nni/nas/benchmarks/nasbench201/db_gen.py
View file @
afe42cea
...
...
@@ -4,8 +4,9 @@ import re
import
tqdm
import
torch
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.constants
import
NONE
,
SKIP_CONNECT
,
CONV_1X1
,
CONV_3X3
,
AVG_POOL_3X3
from
.model
import
db
,
Nb201TrialConfig
,
Nb201TrialStats
,
Nb201IntermediateStats
from
.model
import
Nb201TrialConfig
,
Nb201TrialStats
,
Nb201IntermediateStats
def
parse_arch_str
(
arch_str
):
...
...
@@ -39,6 +40,8 @@ def main():
'imagenet16-120'
:
[
'train'
,
'x-valid'
,
'x-test'
,
'ori-test'
],
}
db
=
load_benchmark
(
'nasbench201'
)
with
db
:
db
.
create_tables
([
Nb201TrialConfig
,
Nb201TrialStats
,
Nb201IntermediateStats
])
print
(
'Loading NAS-Bench-201 pickle...'
)
...
...
nni/nas/benchmarks/nasbench201/model.py
View file @
afe42cea
import
os
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
from
playhouse.sqlite_ext
import
JSONField
,
SqliteExtDatabase
from
nni.nas.benchmarks.constants
import
DATABASE_DIR
from
nni.nas.benchmarks.utils
import
json_dumps
db
=
SqliteExtDatabase
(
os
.
path
.
join
(
DATABASE_DIR
,
'nasbench201.db'
),
autoconnect
=
True
)
proxy
=
Proxy
(
)
class
Nb201TrialConfig
(
Model
):
...
...
@@ -48,7 +45,7 @@ class Nb201TrialConfig(Model):
])
class
Meta
:
database
=
db
database
=
proxy
class
Nb201TrialStats
(
Model
):
...
...
@@ -113,7 +110,7 @@ class Nb201TrialStats(Model):
ori_test_evaluation_time
=
FloatField
()
class
Meta
:
database
=
db
database
=
proxy
class
Nb201IntermediateStats
(
Model
):
...
...
@@ -157,4 +154,4 @@ class Nb201IntermediateStats(Model):
ori_test_loss
=
FloatField
(
null
=
True
)
class
Meta
:
database
=
db
database
=
proxy
nni/nas/benchmarks/nasbench201/query.py
View file @
afe42cea
...
...
@@ -2,7 +2,9 @@ import functools
from
peewee
import
fn
from
playhouse.shortcuts
import
model_to_dict
from
.model
import
Nb201TrialStats
,
Nb201TrialConfig
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
Nb201TrialStats
,
Nb201TrialConfig
,
proxy
def
query_nb201_trial_stats
(
arch
,
num_epochs
,
dataset
,
reduction
=
None
,
include_intermediates
=
False
):
...
...
@@ -32,6 +34,10 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i
A generator of :class:`nni.nas.benchmark.nasbench201.Nb201TrialStats` objects,
where each of them has been converted into a dict.
"""
if
proxy
.
obj
is
None
:
proxy
.
initialize
(
load_benchmark
(
'nasbench201'
))
fields
=
[]
if
reduction
==
'none'
:
reduction
=
None
...
...
nni/nas/benchmarks/nds/db_gen.py
View file @
afe42cea
...
...
@@ -5,7 +5,8 @@ import os
import
numpy
as
np
import
tqdm
from
.model
import
db
,
NdsTrialConfig
,
NdsTrialStats
,
NdsIntermediateStats
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
NdsTrialConfig
,
NdsTrialStats
,
NdsIntermediateStats
def
inject_item
(
db
,
item
,
proposer
,
dataset
,
generator
):
...
...
@@ -120,6 +121,8 @@ def main():
'Vanilla_rng3.json'
]
db
=
load_benchmark
(
'nds'
)
with
db
:
db
.
create_tables
([
NdsTrialConfig
,
NdsTrialStats
,
NdsIntermediateStats
])
for
json_idx
,
json_file
in
enumerate
(
sweep_list
,
start
=
1
):
...
...
nni/nas/benchmarks/nds/model.py
View file @
afe42cea
import
os
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
from
playhouse.sqlite_ext
import
JSONField
,
SqliteExtDatabase
from
nni.nas.benchmarks.constants
import
DATABASE_DIR
from
nni.nas.benchmarks.utils
import
json_dumps
db
=
SqliteExtDatabase
(
os
.
path
.
join
(
DATABASE_DIR
,
'nds.db'
),
autoconnect
=
True
)
proxy
=
Proxy
(
)
class
NdsTrialConfig
(
Model
):
...
...
@@ -67,7 +64,7 @@ class NdsTrialConfig(Model):
num_epochs
=
IntegerField
()
class
Meta
:
database
=
db
database
=
proxy
class
NdsTrialStats
(
Model
):
...
...
@@ -112,7 +109,7 @@ class NdsTrialStats(Model):
iter_time
=
FloatField
()
class
Meta
:
database
=
db
database
=
proxy
class
NdsIntermediateStats
(
Model
):
...
...
@@ -140,4 +137,4 @@ class NdsIntermediateStats(Model):
test_acc
=
FloatField
()
class
Meta
:
database
=
db
database
=
proxy
nni/nas/benchmarks/nds/query.py
View file @
afe42cea
...
...
@@ -2,7 +2,9 @@ import functools
from
peewee
import
fn
from
playhouse.shortcuts
import
model_to_dict
from
.model
import
NdsTrialStats
,
NdsTrialConfig
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
NdsTrialStats
,
NdsTrialConfig
,
proxy
def
query_nds_trial_stats
(
model_family
,
proposer
,
generator
,
model_spec
,
cell_spec
,
dataset
,
...
...
@@ -41,6 +43,10 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
A generator of :class:`nni.nas.benchmark.nds.NdsTrialStats` objects,
where each of them has been converted into a dict.
"""
if
proxy
.
obj
is
None
:
proxy
.
initialize
(
load_benchmark
(
'nds'
))
fields
=
[]
if
reduction
==
'none'
:
reduction
=
None
...
...
nni/nas/benchmarks/utils.py
View file @
afe42cea
import
functools
import
hashlib
import
json
import
logging
import
os
import
shutil
import
tempfile
from
pathlib
import
Path
import
requests
import
tqdm
from
playhouse.sqlite_ext
import
SqliteExtDatabase
from
.constants
import
DB_URLS
,
DATABASE_DIR
json_dumps
=
functools
.
partial
(
json
.
dumps
,
sort_keys
=
True
)
# to prevent repetitive loading of benchmarks
_loaded_benchmarks
=
{}
def
load_or_download_file
(
local_path
:
str
,
download_url
:
str
,
download
:
bool
=
False
,
progress
:
bool
=
True
):
f
=
None
hash_prefix
=
Path
(
local_path
).
stem
.
split
(
'-'
)[
-
1
]
_logger
=
logging
.
getLogger
(
__name__
)
try
:
sha256
=
hashlib
.
sha256
()
if
Path
(
local_path
).
exists
():
_logger
.
info
(
'"%s" already exists. Checking hash.'
,
local_path
)
with
Path
(
local_path
).
open
(
'rb'
)
as
fr
:
while
True
:
chunk
=
fr
.
read
(
8192
)
if
len
(
chunk
)
==
0
:
break
sha256
.
update
(
chunk
)
elif
download
:
_logger
.
info
(
'"%s" does not exist. Downloading "%s"'
,
local_path
,
download_url
)
# Follow download implementation in torchvision:
# We deliberately save it in a temp file and move it after
# download is complete. This prevents a local working checkpoint
# being overridden by a broken download.
dst_dir
=
Path
(
local_path
).
parent
dst_dir
.
mkdir
(
exist_ok
=
True
,
parents
=
True
)
f
=
tempfile
.
NamedTemporaryFile
(
delete
=
False
,
dir
=
dst_dir
)
r
=
requests
.
get
(
download_url
,
stream
=
True
)
total_length
=
int
(
r
.
headers
.
get
(
'content-length'
))
with
tqdm
.
tqdm
(
total
=
total_length
,
disable
=
not
progress
,
unit
=
'B'
,
unit_scale
=
True
,
unit_divisor
=
1024
)
as
pbar
:
for
chunk
in
r
.
iter_content
(
8192
):
f
.
write
(
chunk
)
sha256
.
update
(
chunk
)
pbar
.
update
(
len
(
chunk
))
f
.
flush
()
else
:
raise
FileNotFoundError
(
'Download is not enabled, but file still does not exist: {}'
.
format
(
local_path
))
digest
=
sha256
.
hexdigest
()
if
not
digest
.
startswith
(
hash_prefix
):
raise
RuntimeError
(
'Invalid hash value (expected "{}", got "{}")'
.
format
(
hash_prefix
,
digest
))
if
f
is
not
None
:
shutil
.
move
(
f
.
name
,
local_path
)
finally
:
if
f
is
not
None
:
f
.
close
()
if
os
.
path
.
exists
(
f
.
name
):
os
.
remove
(
f
.
name
)
def
load_benchmark
(
benchmark
:
str
)
->
SqliteExtDatabase
:
"""
Load a benchmark as a database.
Parmaeters
----------
benchmark : str
Benchmark name like nasbench201.
"""
if
benchmark
in
_loaded_benchmarks
:
return
_loaded_benchmarks
[
benchmark
]
url
=
DB_URLS
[
benchmark
]
local_path
=
os
.
path
.
join
(
DATABASE_DIR
,
os
.
path
.
basename
(
url
))
load_or_download_file
(
local_path
,
url
)
_loaded_benchmarks
[
benchmark
]
=
SqliteExtDatabase
(
local_path
,
autoconnect
=
True
)
return
_loaded_benchmarks
[
benchmark
]
def
download_benchmark
(
benchmark
:
str
,
progress
:
bool
=
True
):
"""
Download a converted benchmark.
Parameters
----------
benchmark : str
Benchmark name like nasbench201.
"""
url
=
DB_URLS
[
benchmark
]
local_path
=
os
.
path
.
join
(
DATABASE_DIR
,
os
.
path
.
basename
(
url
))
load_or_download_file
(
local_path
,
url
,
True
,
progress
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment