Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
nni
Commits
afe42cea
"...text-generation-inference.git" did not exist on "2b19d671b4d1020e31276477f278ca87cfa37a3c"
Unverified
Commit
afe42cea
authored
Oct 12, 2021
by
Yuge Zhang
Committed by
GitHub
Oct 12, 2021
Browse files
NAS benchmark integration (stage 2) - download (#4205)
parent
000de04b
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
170 additions
and
35 deletions
+170
-35
nni/nas/benchmarks/__init__.py
nni/nas/benchmarks/__init__.py
+1
-0
nni/nas/benchmarks/constants.py
nni/nas/benchmarks/constants.py
+18
-2
nni/nas/benchmarks/nasbench101/db_gen.py
nni/nas/benchmarks/nasbench101/db_gen.py
+4
-1
nni/nas/benchmarks/nasbench101/model.py
nni/nas/benchmarks/nasbench101/model.py
+6
-9
nni/nas/benchmarks/nasbench101/query.py
nni/nas/benchmarks/nasbench101/query.py
+7
-1
nni/nas/benchmarks/nasbench201/db_gen.py
nni/nas/benchmarks/nasbench201/db_gen.py
+4
-1
nni/nas/benchmarks/nasbench201/model.py
nni/nas/benchmarks/nasbench201/model.py
+6
-9
nni/nas/benchmarks/nasbench201/query.py
nni/nas/benchmarks/nasbench201/query.py
+7
-1
nni/nas/benchmarks/nds/db_gen.py
nni/nas/benchmarks/nds/db_gen.py
+4
-1
nni/nas/benchmarks/nds/model.py
nni/nas/benchmarks/nds/model.py
+6
-9
nni/nas/benchmarks/nds/query.py
nni/nas/benchmarks/nds/query.py
+7
-1
nni/nas/benchmarks/utils.py
nni/nas/benchmarks/utils.py
+100
-0
No files found.
nni/nas/benchmarks/__init__.py
View file @
afe42cea
from
.utils
import
load_benchmark
,
download_benchmark
nni/nas/benchmarks/constants.py
View file @
afe42cea
import
os
import
os
# TODO: need to be refactored to support automatic download
ENV_NNI_HOME
=
'NNI_HOME'
ENV_XDG_CACHE_HOME
=
'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR
=
'~/.cache'
DATABASE_DIR
=
os
.
environ
.
get
(
"NASBENCHMARK_DIR"
,
os
.
path
.
expanduser
(
"~/.nni/nasbenchmark"
))
def
_get_nasbenchmark_dir
():
nni_home
=
os
.
path
.
expanduser
(
os
.
getenv
(
ENV_NNI_HOME
,
os
.
path
.
join
(
os
.
getenv
(
ENV_XDG_CACHE_HOME
,
DEFAULT_CACHE_DIR
),
'nni'
)))
return
os
.
path
.
join
(
nni_home
,
'nasbenchmark'
)
DATABASE_DIR
=
_get_nasbenchmark_dir
()
DB_URLS
=
{
'nasbench101'
:
'https://nni.blob.core.windows.net/nasbenchmark/nasbench101-209f5694.db'
,
'nasbench201'
:
'https://nni.blob.core.windows.net/nasbenchmark/nasbench201-b2b60732.db'
,
'nds'
:
'https://nni.blob.core.windows.net/nasbenchmark/nds-5745c235.db'
}
nni/nas/benchmarks/nasbench101/db_gen.py
View file @
afe42cea
...
@@ -3,7 +3,8 @@ import argparse
...
@@ -3,7 +3,8 @@ import argparse
from
tqdm
import
tqdm
from
tqdm
import
tqdm
from
nasbench
import
api
# pylint: disable=import-error
from
nasbench
import
api
# pylint: disable=import-error
from
.model
import
db
,
Nb101TrialConfig
,
Nb101TrialStats
,
Nb101IntermediateStats
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
Nb101TrialConfig
,
Nb101TrialStats
,
Nb101IntermediateStats
from
.graph_util
import
nasbench_format_to_architecture_repr
,
hash_module
from
.graph_util
import
nasbench_format_to_architecture_repr
,
hash_module
...
@@ -13,6 +14,8 @@ def main():
...
@@ -13,6 +14,8 @@ def main():
help
=
'Path to the file to be converted, e.g., nasbench_full.tfrecord'
)
help
=
'Path to the file to be converted, e.g., nasbench_full.tfrecord'
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
nasbench
=
api
.
NASBench
(
args
.
input_file
)
nasbench
=
api
.
NASBench
(
args
.
input_file
)
db
=
load_benchmark
(
'nasbench101'
)
with
db
:
with
db
:
db
.
create_tables
([
Nb101TrialConfig
,
Nb101TrialStats
,
Nb101IntermediateStats
])
db
.
create_tables
([
Nb101TrialConfig
,
Nb101TrialStats
,
Nb101IntermediateStats
])
for
hashval
in
tqdm
(
nasbench
.
hash_iterator
(),
desc
=
'Dumping data into database'
):
for
hashval
in
tqdm
(
nasbench
.
hash_iterator
(),
desc
=
'Dumping data into database'
):
...
...
nni/nas/benchmarks/nasbench101/model.py
View file @
afe42cea
import
os
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
from
playhouse.sqlite_ext
import
JSONField
,
SqliteExtDatabase
from
nni.nas.benchmarks.constants
import
DATABASE_DIR
from
nni.nas.benchmarks.utils
import
json_dumps
from
nni.nas.benchmarks.utils
import
json_dumps
db
=
SqliteExtDatabase
(
os
.
path
.
join
(
DATABASE_DIR
,
'nasbench101.db'
),
autoconnect
=
True
)
proxy
=
Proxy
(
)
class
Nb101TrialConfig
(
Model
):
class
Nb101TrialConfig
(
Model
):
...
@@ -35,7 +32,7 @@ class Nb101TrialConfig(Model):
...
@@ -35,7 +32,7 @@ class Nb101TrialConfig(Model):
num_epochs
=
IntegerField
(
index
=
True
)
num_epochs
=
IntegerField
(
index
=
True
)
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
class
Nb101TrialStats
(
Model
):
class
Nb101TrialStats
(
Model
):
...
@@ -68,7 +65,7 @@ class Nb101TrialStats(Model):
...
@@ -68,7 +65,7 @@ class Nb101TrialStats(Model):
training_time
=
FloatField
()
training_time
=
FloatField
()
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
class
Nb101IntermediateStats
(
Model
):
class
Nb101IntermediateStats
(
Model
):
...
@@ -99,4 +96,4 @@ class Nb101IntermediateStats(Model):
...
@@ -99,4 +96,4 @@ class Nb101IntermediateStats(Model):
training_time
=
FloatField
()
training_time
=
FloatField
()
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
nni/nas/benchmarks/nasbench101/query.py
View file @
afe42cea
...
@@ -2,7 +2,9 @@ import functools
...
@@ -2,7 +2,9 @@ import functools
from
peewee
import
fn
from
peewee
import
fn
from
playhouse.shortcuts
import
model_to_dict
from
playhouse.shortcuts
import
model_to_dict
from
.model
import
Nb101TrialStats
,
Nb101TrialConfig
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
Nb101TrialStats
,
Nb101TrialConfig
,
proxy
from
.graph_util
import
hash_module
,
infer_num_vertices
from
.graph_util
import
hash_module
,
infer_num_vertices
...
@@ -33,6 +35,10 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None,
...
@@ -33,6 +35,10 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None,
A generator of :class:`nni.nas.benchmark.nasbench101.Nb101TrialStats` objects,
A generator of :class:`nni.nas.benchmark.nasbench101.Nb101TrialStats` objects,
where each of them has been converted into a dict.
where each of them has been converted into a dict.
"""
"""
if
proxy
.
obj
is
None
:
proxy
.
initialize
(
load_benchmark
(
'nasbench101'
))
fields
=
[]
fields
=
[]
if
reduction
==
'none'
:
if
reduction
==
'none'
:
reduction
=
None
reduction
=
None
...
...
nni/nas/benchmarks/nasbench201/db_gen.py
View file @
afe42cea
...
@@ -4,8 +4,9 @@ import re
...
@@ -4,8 +4,9 @@ import re
import
tqdm
import
tqdm
import
torch
import
torch
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.constants
import
NONE
,
SKIP_CONNECT
,
CONV_1X1
,
CONV_3X3
,
AVG_POOL_3X3
from
.constants
import
NONE
,
SKIP_CONNECT
,
CONV_1X1
,
CONV_3X3
,
AVG_POOL_3X3
from
.model
import
db
,
Nb201TrialConfig
,
Nb201TrialStats
,
Nb201IntermediateStats
from
.model
import
Nb201TrialConfig
,
Nb201TrialStats
,
Nb201IntermediateStats
def
parse_arch_str
(
arch_str
):
def
parse_arch_str
(
arch_str
):
...
@@ -39,6 +40,8 @@ def main():
...
@@ -39,6 +40,8 @@ def main():
'imagenet16-120'
:
[
'train'
,
'x-valid'
,
'x-test'
,
'ori-test'
],
'imagenet16-120'
:
[
'train'
,
'x-valid'
,
'x-test'
,
'ori-test'
],
}
}
db
=
load_benchmark
(
'nasbench201'
)
with
db
:
with
db
:
db
.
create_tables
([
Nb201TrialConfig
,
Nb201TrialStats
,
Nb201IntermediateStats
])
db
.
create_tables
([
Nb201TrialConfig
,
Nb201TrialStats
,
Nb201IntermediateStats
])
print
(
'Loading NAS-Bench-201 pickle...'
)
print
(
'Loading NAS-Bench-201 pickle...'
)
...
...
nni/nas/benchmarks/nasbench201/model.py
View file @
afe42cea
import
os
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
from
playhouse.sqlite_ext
import
JSONField
,
SqliteExtDatabase
from
nni.nas.benchmarks.constants
import
DATABASE_DIR
from
nni.nas.benchmarks.utils
import
json_dumps
from
nni.nas.benchmarks.utils
import
json_dumps
db
=
SqliteExtDatabase
(
os
.
path
.
join
(
DATABASE_DIR
,
'nasbench201.db'
),
autoconnect
=
True
)
proxy
=
Proxy
(
)
class
Nb201TrialConfig
(
Model
):
class
Nb201TrialConfig
(
Model
):
...
@@ -48,7 +45,7 @@ class Nb201TrialConfig(Model):
...
@@ -48,7 +45,7 @@ class Nb201TrialConfig(Model):
])
])
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
class
Nb201TrialStats
(
Model
):
class
Nb201TrialStats
(
Model
):
...
@@ -113,7 +110,7 @@ class Nb201TrialStats(Model):
...
@@ -113,7 +110,7 @@ class Nb201TrialStats(Model):
ori_test_evaluation_time
=
FloatField
()
ori_test_evaluation_time
=
FloatField
()
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
class
Nb201IntermediateStats
(
Model
):
class
Nb201IntermediateStats
(
Model
):
...
@@ -157,4 +154,4 @@ class Nb201IntermediateStats(Model):
...
@@ -157,4 +154,4 @@ class Nb201IntermediateStats(Model):
ori_test_loss
=
FloatField
(
null
=
True
)
ori_test_loss
=
FloatField
(
null
=
True
)
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
nni/nas/benchmarks/nasbench201/query.py
View file @
afe42cea
...
@@ -2,7 +2,9 @@ import functools
...
@@ -2,7 +2,9 @@ import functools
from
peewee
import
fn
from
peewee
import
fn
from
playhouse.shortcuts
import
model_to_dict
from
playhouse.shortcuts
import
model_to_dict
from
.model
import
Nb201TrialStats
,
Nb201TrialConfig
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
Nb201TrialStats
,
Nb201TrialConfig
,
proxy
def
query_nb201_trial_stats
(
arch
,
num_epochs
,
dataset
,
reduction
=
None
,
include_intermediates
=
False
):
def
query_nb201_trial_stats
(
arch
,
num_epochs
,
dataset
,
reduction
=
None
,
include_intermediates
=
False
):
...
@@ -32,6 +34,10 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i
...
@@ -32,6 +34,10 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i
A generator of :class:`nni.nas.benchmark.nasbench201.Nb201TrialStats` objects,
A generator of :class:`nni.nas.benchmark.nasbench201.Nb201TrialStats` objects,
where each of them has been converted into a dict.
where each of them has been converted into a dict.
"""
"""
if
proxy
.
obj
is
None
:
proxy
.
initialize
(
load_benchmark
(
'nasbench201'
))
fields
=
[]
fields
=
[]
if
reduction
==
'none'
:
if
reduction
==
'none'
:
reduction
=
None
reduction
=
None
...
...
nni/nas/benchmarks/nds/db_gen.py
View file @
afe42cea
...
@@ -5,7 +5,8 @@ import os
...
@@ -5,7 +5,8 @@ import os
import
numpy
as
np
import
numpy
as
np
import
tqdm
import
tqdm
from
.model
import
db
,
NdsTrialConfig
,
NdsTrialStats
,
NdsIntermediateStats
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
NdsTrialConfig
,
NdsTrialStats
,
NdsIntermediateStats
def
inject_item
(
db
,
item
,
proposer
,
dataset
,
generator
):
def
inject_item
(
db
,
item
,
proposer
,
dataset
,
generator
):
...
@@ -120,6 +121,8 @@ def main():
...
@@ -120,6 +121,8 @@ def main():
'Vanilla_rng3.json'
'Vanilla_rng3.json'
]
]
db
=
load_benchmark
(
'nds'
)
with
db
:
with
db
:
db
.
create_tables
([
NdsTrialConfig
,
NdsTrialStats
,
NdsIntermediateStats
])
db
.
create_tables
([
NdsTrialConfig
,
NdsTrialStats
,
NdsIntermediateStats
])
for
json_idx
,
json_file
in
enumerate
(
sweep_list
,
start
=
1
):
for
json_idx
,
json_file
in
enumerate
(
sweep_list
,
start
=
1
):
...
...
nni/nas/benchmarks/nds/model.py
View file @
afe42cea
import
os
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
,
Proxy
from
playhouse.sqlite_ext
import
JSONField
from
peewee
import
CharField
,
FloatField
,
ForeignKeyField
,
IntegerField
,
Model
from
playhouse.sqlite_ext
import
JSONField
,
SqliteExtDatabase
from
nni.nas.benchmarks.constants
import
DATABASE_DIR
from
nni.nas.benchmarks.utils
import
json_dumps
from
nni.nas.benchmarks.utils
import
json_dumps
db
=
SqliteExtDatabase
(
os
.
path
.
join
(
DATABASE_DIR
,
'nds.db'
),
autoconnect
=
True
)
proxy
=
Proxy
(
)
class
NdsTrialConfig
(
Model
):
class
NdsTrialConfig
(
Model
):
...
@@ -67,7 +64,7 @@ class NdsTrialConfig(Model):
...
@@ -67,7 +64,7 @@ class NdsTrialConfig(Model):
num_epochs
=
IntegerField
()
num_epochs
=
IntegerField
()
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
class
NdsTrialStats
(
Model
):
class
NdsTrialStats
(
Model
):
...
@@ -112,7 +109,7 @@ class NdsTrialStats(Model):
...
@@ -112,7 +109,7 @@ class NdsTrialStats(Model):
iter_time
=
FloatField
()
iter_time
=
FloatField
()
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
class
NdsIntermediateStats
(
Model
):
class
NdsIntermediateStats
(
Model
):
...
@@ -140,4 +137,4 @@ class NdsIntermediateStats(Model):
...
@@ -140,4 +137,4 @@ class NdsIntermediateStats(Model):
test_acc
=
FloatField
()
test_acc
=
FloatField
()
class
Meta
:
class
Meta
:
database
=
db
database
=
proxy
nni/nas/benchmarks/nds/query.py
View file @
afe42cea
...
@@ -2,7 +2,9 @@ import functools
...
@@ -2,7 +2,9 @@ import functools
from
peewee
import
fn
from
peewee
import
fn
from
playhouse.shortcuts
import
model_to_dict
from
playhouse.shortcuts
import
model_to_dict
from
.model
import
NdsTrialStats
,
NdsTrialConfig
from
nni.nas.benchmarks.utils
import
load_benchmark
from
.model
import
NdsTrialStats
,
NdsTrialConfig
,
proxy
def
query_nds_trial_stats
(
model_family
,
proposer
,
generator
,
model_spec
,
cell_spec
,
dataset
,
def
query_nds_trial_stats
(
model_family
,
proposer
,
generator
,
model_spec
,
cell_spec
,
dataset
,
...
@@ -41,6 +43,10 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
...
@@ -41,6 +43,10 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
A generator of :class:`nni.nas.benchmark.nds.NdsTrialStats` objects,
A generator of :class:`nni.nas.benchmark.nds.NdsTrialStats` objects,
where each of them has been converted into a dict.
where each of them has been converted into a dict.
"""
"""
if
proxy
.
obj
is
None
:
proxy
.
initialize
(
load_benchmark
(
'nds'
))
fields
=
[]
fields
=
[]
if
reduction
==
'none'
:
if
reduction
==
'none'
:
reduction
=
None
reduction
=
None
...
...
nni/nas/benchmarks/utils.py
View file @
afe42cea
import
functools
import
functools
import
hashlib
import
json
import
json
import
logging
import
os
import
shutil
import
tempfile
from
pathlib
import
Path
import
requests
import
tqdm
from
playhouse.sqlite_ext
import
SqliteExtDatabase
from
.constants
import
DB_URLS
,
DATABASE_DIR
json_dumps
=
functools
.
partial
(
json
.
dumps
,
sort_keys
=
True
)
json_dumps
=
functools
.
partial
(
json
.
dumps
,
sort_keys
=
True
)
# to prevent repetitive loading of benchmarks
_loaded_benchmarks
=
{}
def
load_or_download_file
(
local_path
:
str
,
download_url
:
str
,
download
:
bool
=
False
,
progress
:
bool
=
True
):
f
=
None
hash_prefix
=
Path
(
local_path
).
stem
.
split
(
'-'
)[
-
1
]
_logger
=
logging
.
getLogger
(
__name__
)
try
:
sha256
=
hashlib
.
sha256
()
if
Path
(
local_path
).
exists
():
_logger
.
info
(
'"%s" already exists. Checking hash.'
,
local_path
)
with
Path
(
local_path
).
open
(
'rb'
)
as
fr
:
while
True
:
chunk
=
fr
.
read
(
8192
)
if
len
(
chunk
)
==
0
:
break
sha256
.
update
(
chunk
)
elif
download
:
_logger
.
info
(
'"%s" does not exist. Downloading "%s"'
,
local_path
,
download_url
)
# Follow download implementation in torchvision:
# We deliberately save it in a temp file and move it after
# download is complete. This prevents a local working checkpoint
# being overridden by a broken download.
dst_dir
=
Path
(
local_path
).
parent
dst_dir
.
mkdir
(
exist_ok
=
True
,
parents
=
True
)
f
=
tempfile
.
NamedTemporaryFile
(
delete
=
False
,
dir
=
dst_dir
)
r
=
requests
.
get
(
download_url
,
stream
=
True
)
total_length
=
int
(
r
.
headers
.
get
(
'content-length'
))
with
tqdm
.
tqdm
(
total
=
total_length
,
disable
=
not
progress
,
unit
=
'B'
,
unit_scale
=
True
,
unit_divisor
=
1024
)
as
pbar
:
for
chunk
in
r
.
iter_content
(
8192
):
f
.
write
(
chunk
)
sha256
.
update
(
chunk
)
pbar
.
update
(
len
(
chunk
))
f
.
flush
()
else
:
raise
FileNotFoundError
(
'Download is not enabled, but file still does not exist: {}'
.
format
(
local_path
))
digest
=
sha256
.
hexdigest
()
if
not
digest
.
startswith
(
hash_prefix
):
raise
RuntimeError
(
'Invalid hash value (expected "{}", got "{}")'
.
format
(
hash_prefix
,
digest
))
if
f
is
not
None
:
shutil
.
move
(
f
.
name
,
local_path
)
finally
:
if
f
is
not
None
:
f
.
close
()
if
os
.
path
.
exists
(
f
.
name
):
os
.
remove
(
f
.
name
)
def
load_benchmark
(
benchmark
:
str
)
->
SqliteExtDatabase
:
"""
Load a benchmark as a database.
Parmaeters
----------
benchmark : str
Benchmark name like nasbench201.
"""
if
benchmark
in
_loaded_benchmarks
:
return
_loaded_benchmarks
[
benchmark
]
url
=
DB_URLS
[
benchmark
]
local_path
=
os
.
path
.
join
(
DATABASE_DIR
,
os
.
path
.
basename
(
url
))
load_or_download_file
(
local_path
,
url
)
_loaded_benchmarks
[
benchmark
]
=
SqliteExtDatabase
(
local_path
,
autoconnect
=
True
)
return
_loaded_benchmarks
[
benchmark
]
def
download_benchmark
(
benchmark
:
str
,
progress
:
bool
=
True
):
"""
Download a converted benchmark.
Parameters
----------
benchmark : str
Benchmark name like nasbench201.
"""
url
=
DB_URLS
[
benchmark
]
local_path
=
os
.
path
.
join
(
DATABASE_DIR
,
os
.
path
.
basename
(
url
))
load_or_download_file
(
local_path
,
url
,
True
,
progress
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment