{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# NAS 基准测试示例"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pprint\n",
    "import time\n",
    "\n",
    "from nni.nas.benchmarks.nasbench101 import query_nb101_trial_stats\n",
    "from nni.nas.benchmarks.nasbench201 import query_nb201_trial_stats\n",
    "from nni.nas.benchmarks.nds import query_nds_trial_stats\n",
    "from nni.nas.benchmarks.nlp import query_nlp_trial_stats\n",
    "\n",
    "# wall-clock start; the last cell reports the total elapsed time\n",
    "start_time = time.time()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NAS-Bench-101"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用以下架构为例:\n",
    "\n",
    "![nas-101](../../img/nas-bench-101-example.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "arch = {\n",
    "    'op1': 'conv3x3-bn-relu',\n",
    "    'op2': 'maxpool3x3',\n",
    "    'op3': 'conv3x3-bn-relu',\n",
    "    'op4': 'conv3x3-bn-relu',\n",
    "    'op5': 'conv1x1-bn-relu',\n",
    "    'input1': [0],\n",
    "    'input2': [1],\n",
    "    'input3': [2],\n",
    "    'input4': [0],\n",
    "    'input5': [0, 3, 4],\n",
    "    'input6': [2, 5]\n",
    "}\n",
    "for t in query_nb101_trial_stats(arch, 108, include_intermediates=True):\n",
    "    pprint.pprint(t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "一个 NAS-Bench-101 的网络结构可以被训练多次。 生成器返回的每一个元素是一个字典,包含了该 Trial 设置(网络结构+超参数)中其中一个训练结果,如训练集/验证集/测试集准确率,训练时间,Epoch数等等。 NAS-Bench-201 和 NDS 的结果遵循了相似的格式。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NAS-Bench-201"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用以下架构为例:\n",
    "\n",
    "![nas-201](../../img/nas-bench-201-example.png)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "arch = {\n",
    "    '0_1': 'avg_pool_3x3',\n",
    "    '0_2': 'conv_1x1',\n",
    "    '1_2': 'skip_connect',\n",
    "    '0_3': 'conv_1x1',\n",
    "    '1_3': 'skip_connect',\n",
    "    '2_3': 'skip_connect'\n",
    "}\n",
    "for t in query_nb201_trial_stats(arch, 200, 'cifar100'):\n",
    "    pprint.pprint(t)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "中间结果也可得到。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "for t in query_nb201_trial_stats(arch, None, 'imagenet16-120', include_intermediates=True):\n",
    "    print(t['config'])\n",
    "    print('Intermediates:', len(t['intermediates']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NDS"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "使用以下架构为例:\n",
    "\n",
    "![nds](../../img/nas-bench-nds-example.png)\n",
    "\n",
    "这里,`bot_muls`, `ds`, `num_gs`, `ss` 和 `ws` 分别表示 \"bottleneck multipliers\", \"depths\", \"number of groups\", \"strides\" 和 \"widths\" 。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_spec = {\n",
    "    'bot_muls': [0.0, 0.25, 0.25, 0.25],\n",
    "    'ds': [1, 16, 1, 4],\n",
    "    'num_gs': [1, 2, 1, 2],\n",
    "    'ss': [1, 1, 2, 2],\n",
    "    'ws': [16, 64, 128, 16]\n",
    "}\n",
    "# Use None as a wildcard\n",
    "for t in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10'):\n",
    "    pprint.pprint(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_spec = {\n",
    "    'bot_muls': [0.0, 0.25, 0.25, 0.25],\n",
    "    'ds': [1, 16, 1, 4],\n",
    "    'num_gs': [1, 2, 1, 2],\n",
    "    'ss': [1, 1, 2, 2],\n",
    "    'ws': [16, 64, 128, 16]\n",
    "}\n",
    "# same query as above, but only show the first ten intermediate results\n",
    "for t in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10', include_intermediates=True):\n",
    "    pprint.pprint(t['intermediates'][:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "model_spec = {'ds': [1, 12, 12, 12], 'ss': [1, 1, 2, 2], 'ws': [16, 24, 24, 40]}\n",
    "for t in query_nds_trial_stats('residual_basic', 'resnet', 'random', model_spec, {}, 'cifar10'):\n",
    "    pprint.pprint(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# get the first one\n",
    "pprint.pprint(next(query_nds_trial_stats('vanilla', None, None, None, None, None)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# query one NAS cell and check that the returned config echoes the requested specs\n",
    "model_spec = {'num_nodes_normal': 5, 'num_nodes_reduce': 5, 'depth': 12, 'width': 32, 'aux': False, 'drop_prob': 0.0}\n",
    "cell_spec = {\n",
    "    'normal_0_op_x': 'avg_pool_3x3',\n",
    "    'normal_0_input_x': 0,\n",
    "    'normal_0_op_y': 'conv_7x1_1x7',\n",
    "    'normal_0_input_y': 1,\n",
    "    'normal_1_op_x': 'sep_conv_3x3',\n",
    "    'normal_1_input_x': 2,\n",
    "    'normal_1_op_y': 'sep_conv_5x5',\n",
    "    'normal_1_input_y': 0,\n",
    "    'normal_2_op_x': 'dil_sep_conv_3x3',\n",
    "    'normal_2_input_x': 2,\n",
    "    'normal_2_op_y': 'dil_sep_conv_3x3',\n",
    "    'normal_2_input_y': 2,\n",
    "    'normal_3_op_x': 'skip_connect',\n",
    "    'normal_3_input_x': 4,\n",
    "    'normal_3_op_y': 'dil_sep_conv_3x3',\n",
    "    'normal_3_input_y': 4,\n",
    "    'normal_4_op_x': 'conv_7x1_1x7',\n",
    "    'normal_4_input_x': 2,\n",
    "    'normal_4_op_y': 'sep_conv_3x3',\n",
    "    'normal_4_input_y': 4,\n",
    "    'normal_concat': [3, 5, 6],\n",
    "    'reduce_0_op_x': 'avg_pool_3x3',\n",
    "    'reduce_0_input_x': 0,\n",
    "    'reduce_0_op_y': 'dil_sep_conv_3x3',\n",
    "    'reduce_0_input_y': 1,\n",
    "    'reduce_1_op_x': 'sep_conv_3x3',\n",
    "    'reduce_1_input_x': 0,\n",
    "    'reduce_1_op_y': 'sep_conv_3x3',\n",
    "    'reduce_1_input_y': 0,\n",
    "    'reduce_2_op_x': 'skip_connect',\n",
    "    'reduce_2_input_x': 2,\n",
    "    'reduce_2_op_y': 'sep_conv_7x7',\n",
    "    'reduce_2_input_y': 0,\n",
    "    'reduce_3_op_x': 'conv_7x1_1x7',\n",
    "    'reduce_3_input_x': 4,\n",
    "    'reduce_3_op_y': 'skip_connect',\n",
    "    'reduce_3_input_y': 4,\n",
    "    'reduce_4_op_x': 'conv_7x1_1x7',\n",
    "    'reduce_4_input_x': 0,\n",
    "    'reduce_4_op_y': 'conv_7x1_1x7',\n",
    "    'reduce_4_input_y': 5,\n",
    "    'reduce_concat': [3, 6]\n",
    "}\n",
    "\n",
    "for t in query_nds_trial_stats('nas_cell', None, None, model_spec, cell_spec, 'cifar10'):\n",
    "    assert t['config']['model_spec'] == model_spec\n",
    "    assert t['config']['cell_spec'] == cell_spec\n",
    "    pprint.pprint(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "tags": []
   },
   "outputs": [],
   "source": [
    "# count the trials whose proposer is 'amoeba' (every other field is a wildcard)\n",
    "print('NDS (amoeba) count:', len(list(query_nds_trial_stats(None, 'amoeba', None, None, None, None, None))))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## NLP"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "pycharm": {
     "metadata": false
    }
   },
   "source": [
    "使用以下两种结构作为示例。 \n",
    "论文中的 arch 被称为嵌套变量的 “recipe”,目前还没有在 NNI 的基准测试中使用。\n",
    "一个架构有多个节点,节点输入和节点操作,您可以参考文档了解更多详细信息。\n",
    "\n",
    "arch1 : \n",
    "\n",
    "\n",
    "arch2 : \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "{'config': {'arch': {'h_new_0_input_0': 'node_3',\n 'h_new_0_input_1': 'node_2',\n 'h_new_0_input_2': 'node_1',\n 'h_new_0_op': 'blend',\n 'node_0_input_0': 'x',\n 'node_0_input_1': 'h_prev_0',\n 'node_0_op': 'linear',\n 'node_1_input_0': 'node_0',\n 'node_1_op': 'activation_tanh',\n 'node_2_input_0': 'h_prev_0',\n 'node_2_input_1': 'node_1',\n 'node_2_input_2': 'x',\n 'node_2_op': 'linear',\n 'node_3_input_0': 'node_2',\n 'node_3_op': 'activation_leaky_relu'},\n 'dataset': 'ptb',\n 'id': 20003},\n 'id': 16291,\n 'test_loss': 4.680262297102549,\n 'train_loss': 4.132040537087838,\n 'training_time': 177.05208373069763,\n 'val_loss': 4.707944253177966}\n"
     ]
    }
   ],
   "source": [
    "arch1 = {\n",
    "    'h_new_0_input_0': 'node_3',\n",
    "    'h_new_0_input_1': 'node_2',\n",
    "    'h_new_0_input_2': 'node_1',\n",
    "    'h_new_0_op': 'blend',\n",
    "    'node_0_input_0': 'x',\n",
    "    'node_0_input_1': 'h_prev_0',\n",
    "    'node_0_op': 'linear',\n",
    "    'node_1_input_0': 'node_0',\n",
    "    'node_1_op': 'activation_tanh',\n",
    "    'node_2_input_0': 'h_prev_0',\n",
    "    'node_2_input_1': 'node_1',\n",
    "    'node_2_input_2': 'x',\n",
    "    'node_2_op': 'linear',\n",
    "    'node_3_input_0': 'node_2',\n",
    "    'node_3_op': 'activation_leaky_relu'\n",
    "}\n",
    "for t in query_nlp_trial_stats(arch=arch1, dataset='ptb'):\n",
    "    pprint.pprint(t)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "[{'current_epoch': 46,\n 'id': 1796,\n 'test_loss': 6.233430054978619,\n 'train_loss': 6.4866799231542664,\n 'training_time': 146.5680329799652,\n 'val_loss': 6.326836978687959},\n {'current_epoch': 47,\n 'id': 1797,\n 'test_loss': 6.2402057403023825,\n 'train_loss': 6.485401405247535,\n 'training_time': 146.05511450767517,\n 'val_loss': 6.3239741605870865},\n {'current_epoch': 48,\n 'id': 1798,\n 'test_loss': 6.351145308363877,\n 'train_loss': 6.611281181173992,\n 'training_time': 145.8849437236786,\n 'val_loss': 6.436160816865809},\n {'current_epoch': 49,\n 'id': 1799,\n 'test_loss': 6.227155079159031,\n 'train_loss': 6.473414458249545,\n 'training_time': 145.51414465904236,\n 'val_loss': 6.313294354607077}]\n"
     ]
    }
   ],
   "source": [
    "arch2 = {\n",
    "    'h_new_0_input_0': 'node_0',\n",
    "    'h_new_0_input_1': 'node_1',\n",
    "    'h_new_0_op': 'elementwise_sum',\n",
    "    'node_0_input_0': 'x',\n",
    "    'node_0_input_1': 'h_prev_0',\n",
    "    'node_0_op': 'linear',\n",
    "    'node_1_input_0': 'node_0',\n",
    "    'node_1_op': 'activation_tanh'\n",
    "}\n",
    "for t in query_nlp_trial_stats(arch=arch2, dataset='wikitext-2', include_intermediates=True):\n",
    "    pprint.pprint(t['intermediates'][45:49])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "pycharm": {},
    "tags": []
   },
   "outputs": [
    {
     "output_type": "stream",
     "name": "stdout",
     "text": [
      "Elapsed time: 5.60982608795166 seconds\n"
     ]
    }
   ],
   "source": [
    "print('Elapsed time:', time.time() - start_time, 'seconds')"
   ]
  }
 ],
 "metadata": {
  "file_extension": ".py",
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "name": "python",
   "version": "3.8.5-final"
  },
  "mimetype": "text/x-python",
  "name": "python",
  "npconvert_exporter": "python",
  "orig_nbformat": 2,
  "pygments_lexer": "ipython3",
  "version": 3
 },
 "nbformat": 4,
 "nbformat_minor": 2
}