"...text-generation-inference.git" did not exist on "2b19d671b4d1020e31276477f278ca87cfa37a3c"
Unverified Commit afe42cea authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

NAS benchmark integration (stage 2) - download (#4205)

parent 000de04b
from .utils import load_benchmark, download_benchmark
import os import os
# TODO: need to be refactored to support automatic download

# Names of the environment variables consulted when locating the NNI home
# directory, and the cache root used when neither variable is set.
ENV_NNI_HOME = 'NNI_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'

# Legacy location: honours NASBENCHMARK_DIR and otherwise points at
# ~/.nni/nasbenchmark. NOTE(review): DATABASE_DIR is assigned again further
# down in this file from _get_nasbenchmark_dir(); confirm which one should win.
DATABASE_DIR = os.environ.get("NASBENCHMARK_DIR", os.path.expanduser("~/.nni/nasbenchmark"))
def _get_nasbenchmark_dir():
    """
    Work out the directory in which benchmark databases are stored.

    The NNI home directory is read from the environment variable named by
    ``ENV_NNI_HOME``. When that variable is unset, it falls back to an ``nni``
    folder under the cache root, which is the value of the variable named by
    ``ENV_XDG_CACHE_HOME`` or, failing that, ``DEFAULT_CACHE_DIR``.
    A leading ``~`` in the chosen home directory is expanded.

    Returns
    -------
    str
        Path of the ``nasbenchmark`` folder inside the NNI home directory.
    """
    cache_root = os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR)
    fallback_home = os.path.join(cache_root, 'nni')
    configured_home = os.getenv(ENV_NNI_HOME, fallback_home)
    expanded_home = os.path.expanduser(configured_home)
    return os.path.join(expanded_home, 'nasbenchmark')
# Directory holding the benchmark database files; computed once, when this
# module is imported.
DATABASE_DIR = _get_nasbenchmark_dir()
# Download location of each converted benchmark database, keyed by benchmark
# name. The hex string after the last '-' in each file name is the prefix that
# the file's SHA-256 digest is checked against after loading or downloading.
DB_URLS = {
'nasbench101': 'https://nni.blob.core.windows.net/nasbenchmark/nasbench101-209f5694.db',
'nasbench201': 'https://nni.blob.core.windows.net/nasbenchmark/nasbench201-b2b60732.db',
'nds': 'https://nni.blob.core.windows.net/nasbenchmark/nds-5745c235.db'
}
...@@ -3,7 +3,8 @@ import argparse ...@@ -3,7 +3,8 @@ import argparse
from tqdm import tqdm from tqdm import tqdm
from nasbench import api # pylint: disable=import-error from nasbench import api # pylint: disable=import-error
from .model import db, Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats from nni.nas.benchmarks.utils import load_benchmark
from .model import Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats
from .graph_util import nasbench_format_to_architecture_repr, hash_module from .graph_util import nasbench_format_to_architecture_repr, hash_module
...@@ -13,6 +14,8 @@ def main(): ...@@ -13,6 +14,8 @@ def main():
help='Path to the file to be converted, e.g., nasbench_full.tfrecord') help='Path to the file to be converted, e.g., nasbench_full.tfrecord')
args = parser.parse_args() args = parser.parse_args()
nasbench = api.NASBench(args.input_file) nasbench = api.NASBench(args.input_file)
db = load_benchmark('nasbench101')
with db: with db:
db.create_tables([Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats]) db.create_tables([Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats])
for hashval in tqdm(nasbench.hash_iterator(), desc='Dumping data into database'): for hashval in tqdm(nasbench.hash_iterator(), desc='Dumping data into database'):
......
import os from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps from nni.nas.benchmarks.utils import json_dumps
db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench101.db'), autoconnect=True) proxy = Proxy()
class Nb101TrialConfig(Model): class Nb101TrialConfig(Model):
...@@ -35,7 +32,7 @@ class Nb101TrialConfig(Model): ...@@ -35,7 +32,7 @@ class Nb101TrialConfig(Model):
num_epochs = IntegerField(index=True) num_epochs = IntegerField(index=True)
class Meta: class Meta:
database = db database = proxy
class Nb101TrialStats(Model): class Nb101TrialStats(Model):
...@@ -68,7 +65,7 @@ class Nb101TrialStats(Model): ...@@ -68,7 +65,7 @@ class Nb101TrialStats(Model):
training_time = FloatField() training_time = FloatField()
class Meta: class Meta:
database = db database = proxy
class Nb101IntermediateStats(Model): class Nb101IntermediateStats(Model):
...@@ -99,4 +96,4 @@ class Nb101IntermediateStats(Model): ...@@ -99,4 +96,4 @@ class Nb101IntermediateStats(Model):
training_time = FloatField() training_time = FloatField()
class Meta: class Meta:
database = db database = proxy
...@@ -2,7 +2,9 @@ import functools ...@@ -2,7 +2,9 @@ import functools
from peewee import fn from peewee import fn
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from .model import Nb101TrialStats, Nb101TrialConfig
from nni.nas.benchmarks.utils import load_benchmark
from .model import Nb101TrialStats, Nb101TrialConfig, proxy
from .graph_util import hash_module, infer_num_vertices from .graph_util import hash_module, infer_num_vertices
...@@ -33,6 +35,10 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None, ...@@ -33,6 +35,10 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None,
A generator of :class:`nni.nas.benchmark.nasbench101.Nb101TrialStats` objects, A generator of :class:`nni.nas.benchmark.nasbench101.Nb101TrialStats` objects,
where each of them has been converted into a dict. where each of them has been converted into a dict.
""" """
if proxy.obj is None:
proxy.initialize(load_benchmark('nasbench101'))
fields = [] fields = []
if reduction == 'none': if reduction == 'none':
reduction = None reduction = None
......
...@@ -4,8 +4,9 @@ import re ...@@ -4,8 +4,9 @@ import re
import tqdm import tqdm
import torch import torch
from nni.nas.benchmarks.utils import load_benchmark
from .constants import NONE, SKIP_CONNECT, CONV_1X1, CONV_3X3, AVG_POOL_3X3 from .constants import NONE, SKIP_CONNECT, CONV_1X1, CONV_3X3, AVG_POOL_3X3
from .model import db, Nb201TrialConfig, Nb201TrialStats, Nb201IntermediateStats from .model import Nb201TrialConfig, Nb201TrialStats, Nb201IntermediateStats
def parse_arch_str(arch_str): def parse_arch_str(arch_str):
...@@ -39,6 +40,8 @@ def main(): ...@@ -39,6 +40,8 @@ def main():
'imagenet16-120': ['train', 'x-valid', 'x-test', 'ori-test'], 'imagenet16-120': ['train', 'x-valid', 'x-test', 'ori-test'],
} }
db = load_benchmark('nasbench201')
with db: with db:
db.create_tables([Nb201TrialConfig, Nb201TrialStats, Nb201IntermediateStats]) db.create_tables([Nb201TrialConfig, Nb201TrialStats, Nb201IntermediateStats])
print('Loading NAS-Bench-201 pickle...') print('Loading NAS-Bench-201 pickle...')
......
import os from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps from nni.nas.benchmarks.utils import json_dumps
db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench201.db'), autoconnect=True) proxy = Proxy()
class Nb201TrialConfig(Model): class Nb201TrialConfig(Model):
...@@ -48,7 +45,7 @@ class Nb201TrialConfig(Model): ...@@ -48,7 +45,7 @@ class Nb201TrialConfig(Model):
]) ])
class Meta: class Meta:
database = db database = proxy
class Nb201TrialStats(Model): class Nb201TrialStats(Model):
...@@ -113,7 +110,7 @@ class Nb201TrialStats(Model): ...@@ -113,7 +110,7 @@ class Nb201TrialStats(Model):
ori_test_evaluation_time = FloatField() ori_test_evaluation_time = FloatField()
class Meta: class Meta:
database = db database = proxy
class Nb201IntermediateStats(Model): class Nb201IntermediateStats(Model):
...@@ -157,4 +154,4 @@ class Nb201IntermediateStats(Model): ...@@ -157,4 +154,4 @@ class Nb201IntermediateStats(Model):
ori_test_loss = FloatField(null=True) ori_test_loss = FloatField(null=True)
class Meta: class Meta:
database = db database = proxy
...@@ -2,7 +2,9 @@ import functools ...@@ -2,7 +2,9 @@ import functools
from peewee import fn from peewee import fn
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from .model import Nb201TrialStats, Nb201TrialConfig
from nni.nas.benchmarks.utils import load_benchmark
from .model import Nb201TrialStats, Nb201TrialConfig, proxy
def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_intermediates=False): def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_intermediates=False):
...@@ -32,6 +34,10 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i ...@@ -32,6 +34,10 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_i
A generator of :class:`nni.nas.benchmark.nasbench201.Nb201TrialStats` objects, A generator of :class:`nni.nas.benchmark.nasbench201.Nb201TrialStats` objects,
where each of them has been converted into a dict. where each of them has been converted into a dict.
""" """
if proxy.obj is None:
proxy.initialize(load_benchmark('nasbench201'))
fields = [] fields = []
if reduction == 'none': if reduction == 'none':
reduction = None reduction = None
......
...@@ -5,7 +5,8 @@ import os ...@@ -5,7 +5,8 @@ import os
import numpy as np import numpy as np
import tqdm import tqdm
from .model import db, NdsTrialConfig, NdsTrialStats, NdsIntermediateStats from nni.nas.benchmarks.utils import load_benchmark
from .model import NdsTrialConfig, NdsTrialStats, NdsIntermediateStats
def inject_item(db, item, proposer, dataset, generator): def inject_item(db, item, proposer, dataset, generator):
...@@ -120,6 +121,8 @@ def main(): ...@@ -120,6 +121,8 @@ def main():
'Vanilla_rng3.json' 'Vanilla_rng3.json'
] ]
db = load_benchmark('nds')
with db: with db:
db.create_tables([NdsTrialConfig, NdsTrialStats, NdsIntermediateStats]) db.create_tables([NdsTrialConfig, NdsTrialStats, NdsIntermediateStats])
for json_idx, json_file in enumerate(sweep_list, start=1): for json_idx, json_file in enumerate(sweep_list, start=1):
......
import os from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model, Proxy
from playhouse.sqlite_ext import JSONField
from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps from nni.nas.benchmarks.utils import json_dumps
db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nds.db'), autoconnect=True) proxy = Proxy()
class NdsTrialConfig(Model): class NdsTrialConfig(Model):
...@@ -67,7 +64,7 @@ class NdsTrialConfig(Model): ...@@ -67,7 +64,7 @@ class NdsTrialConfig(Model):
num_epochs = IntegerField() num_epochs = IntegerField()
class Meta: class Meta:
database = db database = proxy
class NdsTrialStats(Model): class NdsTrialStats(Model):
...@@ -112,7 +109,7 @@ class NdsTrialStats(Model): ...@@ -112,7 +109,7 @@ class NdsTrialStats(Model):
iter_time = FloatField() iter_time = FloatField()
class Meta: class Meta:
database = db database = proxy
class NdsIntermediateStats(Model): class NdsIntermediateStats(Model):
...@@ -140,4 +137,4 @@ class NdsIntermediateStats(Model): ...@@ -140,4 +137,4 @@ class NdsIntermediateStats(Model):
test_acc = FloatField() test_acc = FloatField()
class Meta: class Meta:
database = db database = proxy
...@@ -2,7 +2,9 @@ import functools ...@@ -2,7 +2,9 @@ import functools
from peewee import fn from peewee import fn
from playhouse.shortcuts import model_to_dict from playhouse.shortcuts import model_to_dict
from .model import NdsTrialStats, NdsTrialConfig
from nni.nas.benchmarks.utils import load_benchmark
from .model import NdsTrialStats, NdsTrialConfig, proxy
def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_spec, dataset, def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_spec, dataset,
...@@ -41,6 +43,10 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp ...@@ -41,6 +43,10 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
A generator of :class:`nni.nas.benchmark.nds.NdsTrialStats` objects, A generator of :class:`nni.nas.benchmark.nds.NdsTrialStats` objects,
where each of them has been converted into a dict. where each of them has been converted into a dict.
""" """
if proxy.obj is None:
proxy.initialize(load_benchmark('nds'))
fields = [] fields = []
if reduction == 'none': if reduction == 'none':
reduction = None reduction = None
......
import functools import functools
import hashlib
import json import json
import logging
import os
import shutil
import tempfile
from pathlib import Path
import requests
import tqdm
from playhouse.sqlite_ext import SqliteExtDatabase
from .constants import DB_URLS, DATABASE_DIR
json_dumps = functools.partial(json.dumps, sort_keys=True) json_dumps = functools.partial(json.dumps, sort_keys=True)
# to prevent repetitive loading of benchmarks
_loaded_benchmarks = {}
def load_or_download_file(local_path: str, download_url: str, download: bool = False, progress: bool = True):
    """
    Make sure ``local_path`` exists and matches the hash embedded in its name.

    The expected hash prefix is the part of the file stem after the last
    ``-`` (e.g. ``nds-5745c235.db`` expects a SHA-256 digest starting with
    ``5745c235``). An existing file is only verified; a missing file is
    downloaded when ``download`` is true.

    Parameters
    ----------
    local_path : str
        Destination path of the file. Its stem must end with ``-<hash prefix>``.
    download_url : str
        URL to fetch the file from when it is missing locally.
    download : bool
        Whether a missing file may be downloaded.
    progress : bool
        Whether to show a progress bar while downloading.

    Raises
    ------
    FileNotFoundError
        If the file is missing and ``download`` is false.
    RuntimeError
        If the SHA-256 digest of the (existing or downloaded) file does not
        start with the expected hash prefix.
    """
    _logger = logging.getLogger(__name__)
    target = Path(local_path)
    hash_prefix = target.stem.split('-')[-1]
    sha256 = hashlib.sha256()
    f = None

    try:
        if target.exists():
            _logger.info('"%s" already exists. Checking hash.', local_path)
            with target.open('rb') as fr:
                for chunk in iter(lambda: fr.read(8192), b''):
                    sha256.update(chunk)
        elif download:
            _logger.info('"%s" does not exist. Downloading "%s"', local_path, download_url)

            # Follow download implementation in torchvision:
            # We deliberately save it in a temp file and move it after
            # download is complete. This prevents a local working checkpoint
            # being overridden by a broken download.
            dst_dir = target.parent
            dst_dir.mkdir(exist_ok=True, parents=True)

            f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
            # The context manager releases the streamed connection even if
            # the transfer is interrupted.
            with requests.get(download_url, stream=True) as r:
                # Content-Length is optional (e.g. chunked transfer encoding);
                # tqdm accepts total=None for an unknown size.
                content_length = r.headers.get('content-length')
                total_length = int(content_length) if content_length is not None else None
                with tqdm.tqdm(total=total_length, disable=not progress,
                               unit='B', unit_scale=True, unit_divisor=1024) as pbar:
                    for chunk in r.iter_content(8192):
                        f.write(chunk)
                        sha256.update(chunk)
                        pbar.update(len(chunk))
            f.flush()
        else:
            raise FileNotFoundError('Download is not enabled, but file still does not exist: {}'.format(local_path))

        digest = sha256.hexdigest()
        if not digest.startswith(hash_prefix):
            raise RuntimeError('Invalid hash value (expected "{}", got "{}")'.format(hash_prefix, digest))

        if f is not None:
            # Close before moving: an open file cannot be renamed on Windows.
            f.close()
            shutil.move(f.name, local_path)
    finally:
        if f is not None:
            # Closing twice is harmless; the temp file only still exists here
            # when the download or the hash check failed.
            f.close()
            if os.path.exists(f.name):
                os.remove(f.name)
def load_benchmark(benchmark: str) -> SqliteExtDatabase:
    """
    Open the database of a benchmark, reusing an already opened one.

    The database file is looked up under ``DATABASE_DIR`` using the file name
    taken from the benchmark's URL in ``DB_URLS``. It is verified through
    :func:`load_or_download_file` with downloading left disabled, so the
    file has to be present locally already.

    Parameters
    ----------
    benchmark : str
        Benchmark name like nasbench201.

    Returns
    -------
    SqliteExtDatabase
        The database for this benchmark; repeated calls with the same name
        return the same object.
    """
    try:
        return _loaded_benchmarks[benchmark]
    except KeyError:
        pass

    db_url = DB_URLS[benchmark]
    db_path = os.path.join(DATABASE_DIR, os.path.basename(db_url))
    load_or_download_file(db_path, db_url)

    database = SqliteExtDatabase(db_path, autoconnect=True)
    _loaded_benchmarks[benchmark] = database
    return database
def download_benchmark(benchmark: str, progress: bool = True):
    """
    Fetch the converted database file of a benchmark into ``DATABASE_DIR``.

    The work is delegated to :func:`load_or_download_file` with downloading
    enabled; the target file name is the last component of the benchmark's
    URL in ``DB_URLS``.

    Parameters
    ----------
    benchmark : str
        Benchmark name like nasbench201.
    progress : bool
        Forwarded to :func:`load_or_download_file` to control whether a
        progress bar is displayed during the download.
    """
    source_url = DB_URLS[benchmark]
    file_name = os.path.basename(source_url)
    destination = os.path.join(DATABASE_DIR, file_name)
    load_or_download_file(destination, source_url, download=True, progress=progress)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment