Unverified Commit 131302bc authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

NAS benchmark docs update and backward-compat (#4256)

parent f13a9cd4
......@@ -15,14 +15,21 @@ Prerequisites
-------------
* Please prepare a folder to household all the benchmark databases. By default, it can be found at ``${HOME}/.nni/nasbenchmark``. You can place it anywhere you like, and specify it in ``NASBENCHMARK_DIR`` via ``export NASBENCHMARK_DIR=/path/to/your/nasbenchmark`` before importing NNI.
* Please prepare a folder to household all the benchmark databases. By default, it can be found at ``${HOME}/.cache/nni/nasbenchmark``. Or you can place it anywhere you like, and specify it in ``NASBENCHMARK_DIR`` via ``export NASBENCHMARK_DIR=/path/to/your/nasbenchmark`` before importing NNI.
* Please install ``peewee`` via ``pip3 install peewee``\ , which NNI uses to connect to database.
Data Preparation
----------------
To avoid storage and legality issues, we do not provide any prepared databases. Please follow the following steps.
Option 1 (Recommended)
^^^^^^^^^^^^^^^^^^^^^^
You can download the preprocessed benchmark files via ``python -m nni.nas.benchmarks.download <benchmark_name>``, where ``<benchmark_name>`` can be ``nasbench101``, ``nasbench201``, and etc. Add ``--help`` to the command for supported command line arguments.
Option 2
^^^^^^^^
.. note:: If you have files that are processed before v2.5, it is recommended that you delete them and try option 1.
#.
Clone NNI to your machine and enter ``examples/nas/benchmarks`` directory.
......
import os
ENV_NASBENCHMARK_DIR = 'NASBENCHMARK_DIR'
ENV_NNI_HOME = 'NNI_HOME'
ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
DEFAULT_CACHE_DIR = '~/.cache'
......@@ -10,7 +11,7 @@ def _get_nasbenchmark_dir():
nni_home = os.path.expanduser(
os.getenv(ENV_NNI_HOME,
os.path.join(os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'nni')))
return os.path.join(nni_home, 'nasbenchmark')
return os.getenv(ENV_NASBENCHMARK_DIR, os.path.join(nni_home, 'nasbenchmark'))
DATABASE_DIR = _get_nasbenchmark_dir()
......
import argparse
if __name__ == '__main__':
parser = argparse.ArgumentParser('NAS benchmark downloader')
parser.add_argument('benchmark_name', choices=['nasbench101', 'nasbench201', 'nds'])
args = parser.parse_args()
from .utils import download_benchmark
download_benchmark(args.benchmark_name)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment