"official/vision/ops/__init__.py" did not exist on "13a5e4fb072e64f7f66d64bccef0edd0e44a996a"
Unverified Commit 717877d0 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Add API to query intermediate results in NAS benchmark (#2728)

parent c1844898
......@@ -34,19 +34,22 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the following architecture as an example:<br>\n",
"Use the following architecture as an example:\n",
"\n",
"![nas-101](../../img/nas-bench-101-example.png)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 10,\n 'parameters': 8.55553,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 106147.67578125,\n 'valid_acc': 92.41786599159241}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 11,\n 'parameters': 8.55553,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 106095.05859375,\n 'valid_acc': 92.45793223381042}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 12,\n 'parameters': 8.55553,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 106138.55712890625,\n 'valid_acc': 93.04887652397156}\n"
"text": "{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 10,\n 'intermediates': [{'current_epoch': 54,\n 'id': 19,\n 'test_acc': 77.40384340286255,\n 'train_acc': 82.82251358032227,\n 'training_time': 883.4580078125,\n 'valid_acc': 77.76442170143127},\n {'current_epoch': 108,\n 'id': 20,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 1769.1279296875,\n 'valid_acc': 92.41786599159241}],\n 'parameters': 8.55553,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 106147.67578125,\n 'valid_acc': 92.41786599159241}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 11,\n 'intermediates': [{'current_epoch': 54,\n 'id': 21,\n 'test_acc': 82.04126358032227,\n 'train_acc': 87.96073794364929,\n 'training_time': 883.6810302734375,\n 'valid_acc': 82.91265964508057},\n {'current_epoch': 108,\n 'id': 22,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 1768.2509765625,\n 'valid_acc': 92.45793223381042}],\n 'parameters': 8.55553,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 106095.05859375,\n 'valid_acc': 92.45793223381042}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 12,\n 'intermediates': [{'current_epoch': 54,\n 'id': 23,\n 'test_acc': 80.58894276618958,\n 'train_acc': 86.34815812110901,\n 'training_time': 883.4569702148438,\n 'valid_acc': 81.1598539352417},\n {'current_epoch': 108,\n 'id': 24,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 1768.9759521484375,\n 'valid_acc': 93.04887652397156}],\n 'parameters': 8.55553,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 106138.55712890625,\n 'valid_acc': 93.04887652397156}\n"
}
],
"source": [
......@@ -63,7 +66,7 @@
" 'input5': [0, 3, 4],\n",
" 'input6': [2, 5]\n",
"}\n",
"for t in query_nb101_trial_stats(arch, 108):\n",
"for t in query_nb101_trial_stats(arch, 108, include_intermediates=True):\n",
" pprint.pprint(t)"
]
},
......@@ -85,14 +88,17 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Use the following architecture as an example:<br>\n",
"Use the following architecture as an example:\n",
"\n",
"![nas-201](../../img/nas-bench-201-example.png)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
......@@ -113,6 +119,32 @@
" pprint.pprint(t)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Intermediate results are also available."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "{'id': 4, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 12, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 12\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n"
}
],
"source": [
"for t in query_nb201_trial_stats(arch, None, 'imagenet16-120', include_intermediates=True):\n",
" print(t['config'])\n",
" print('Intermediates:', len(t['intermediates']))"
]
},
{
"cell_type": "markdown",
"metadata": {},
......@@ -132,8 +164,10 @@
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"execution_count": 5,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
......@@ -156,8 +190,35 @@
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"execution_count": 6,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "[{'current_epoch': 1,\n 'id': 4494501,\n 'test_acc': 41.76,\n 'train_acc': 30.421000000000006,\n 'train_loss': 1.793},\n {'current_epoch': 2,\n 'id': 4494502,\n 'test_acc': 54.66,\n 'train_acc': 47.24,\n 'train_loss': 1.415},\n {'current_epoch': 3,\n 'id': 4494503,\n 'test_acc': 59.97,\n 'train_acc': 56.983,\n 'train_loss': 1.179},\n {'current_epoch': 4,\n 'id': 4494504,\n 'test_acc': 62.91,\n 'train_acc': 61.955,\n 'train_loss': 1.048},\n {'current_epoch': 5,\n 'id': 4494505,\n 'test_acc': 66.16,\n 'train_acc': 64.493,\n 'train_loss': 0.983},\n {'current_epoch': 6,\n 'id': 4494506,\n 'test_acc': 66.5,\n 'train_acc': 66.274,\n 'train_loss': 0.937},\n {'current_epoch': 7,\n 'id': 4494507,\n 'test_acc': 67.55,\n 'train_acc': 67.426,\n 'train_loss': 0.907},\n {'current_epoch': 8,\n 'id': 4494508,\n 'test_acc': 69.45,\n 'train_acc': 68.45400000000001,\n 'train_loss': 0.878},\n {'current_epoch': 9,\n 'id': 4494509,\n 'test_acc': 70.14,\n 'train_acc': 69.295,\n 'train_loss': 0.857},\n {'current_epoch': 10,\n 'id': 4494510,\n 'test_acc': 69.47,\n 'train_acc': 70.304,\n 'train_loss': 0.832}]\n"
}
],
"source": [
"model_spec = {\n",
" 'bot_muls': [0.0, 0.25, 0.25, 0.25],\n",
" 'ds': [1, 16, 1, 4],\n",
" 'num_gs': [1, 2, 1, 2],\n",
" 'ss': [1, 1, 2, 2],\n",
" 'ws': [16, 64, 128, 16]\n",
"}\n",
"for t in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10', include_intermediates=True):\n",
" pprint.pprint(t['intermediates'][:10])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
......@@ -173,8 +234,10 @@
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"execution_count": 8,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
......@@ -189,8 +252,10 @@
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"execution_count": 9,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
......@@ -254,8 +319,10 @@
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"execution_count": 10,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
......@@ -270,13 +337,15 @@
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"execution_count": 11,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "Elapsed time: 1.9107539653778076 seconds\n"
"text": "Elapsed time: 2.2023813724517822 seconds\n"
}
],
"source": [
......
......@@ -4,6 +4,7 @@ from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps
db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench101.db'), autoconnect=True)
......@@ -28,7 +29,7 @@ class Nb101TrialConfig(Model):
Number of epochs planned for this trial. Should be one of 4, 12, 36, 108 in default setup.
"""
arch = JSONField(index=True)
arch = JSONField(json_dumps=json_dumps, index=True)
num_vertices = IntegerField(index=True)
hash = CharField(max_length=64, index=True)
num_epochs = IntegerField(index=True)
......
......@@ -6,7 +6,7 @@ from .model import Nb101TrialStats, Nb101TrialConfig
from .graph_util import hash_module, infer_num_vertices
def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None):
def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None, include_intermediates=False):
"""
Query trial stats of NAS-Bench-101 given conditions.
......@@ -24,6 +24,8 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None):
reduction : str or None
If 'none' or None, all trial stats will be returned directly.
If 'mean', fields in trial stats will be averaged given the same trial config.
include_intermediates : boolean
If true, intermediate results will be returned.
Returns
-------
......@@ -56,5 +58,13 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None):
query = query.where(functools.reduce(lambda a, b: a & b, conditions))
if reduction is not None:
query = query.group_by(Nb101TrialStats.config)
for k in query:
yield model_to_dict(k)
for trial in query:
if include_intermediates:
data = model_to_dict(trial)
# exclude 'trial' from intermediates as it is already available in data
data['intermediates'] = [
{k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates
]
yield data
else:
yield model_to_dict(trial)
......@@ -4,6 +4,7 @@ from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps
db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench201.db'), autoconnect=True)
......@@ -35,7 +36,7 @@ class Nb201TrialConfig(Model):
for training, 6k images from validation set for validation and the other 6k for testing).
"""
arch = JSONField(index=True)
arch = JSONField(json_dumps=json_dumps, index=True)
num_epochs = IntegerField(index=True)
num_channels = IntegerField()
num_cells = IntegerField()
......
......@@ -5,7 +5,7 @@ from playhouse.shortcuts import model_to_dict
from .model import Nb201TrialStats, Nb201TrialConfig
def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None):
def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_intermediates=False):
"""
Query trial stats of NAS-Bench-201 given conditions.
......@@ -23,6 +23,8 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None):
reduction : str or None
If 'none' or None, all trial stats will be returned directly.
If 'mean', fields in trial stats will be averaged given the same trial config.
include_intermediates : boolean
If true, intermediate results will be returned.
Returns
-------
......@@ -53,5 +55,13 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None):
query = query.where(functools.reduce(lambda a, b: a & b, conditions))
if reduction is not None:
query = query.group_by(Nb201TrialStats.config)
for k in query:
yield model_to_dict(k)
for trial in query:
if include_intermediates:
data = model_to_dict(trial)
# exclude 'trial' from intermediates as it is already available in data
data['intermediates'] = [
{k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates
]
yield data
else:
yield model_to_dict(trial)
......@@ -4,6 +4,7 @@ from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps
db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nds.db'), autoconnect=True)
......@@ -52,8 +53,8 @@ class NdsTrialConfig(Model):
'residual_basic',
'vanilla',
])
model_spec = JSONField(index=True)
cell_spec = JSONField(index=True, null=True)
model_spec = JSONField(json_dumps=json_dumps, index=True)
cell_spec = JSONField(json_dumps=json_dumps, index=True, null=True)
dataset = CharField(max_length=15, index=True, choices=['cifar10', 'imagenet'])
generator = CharField(max_length=15, index=True, choices=[
'random',
......
......@@ -6,7 +6,7 @@ from .model import NdsTrialStats, NdsTrialConfig
def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_spec, dataset,
num_epochs=None, reduction=None):
num_epochs=None, reduction=None, include_intermediates=False):
"""
Query trial stats of NDS given conditions.
......@@ -32,6 +32,8 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
reduction : str or None
If 'none' or None, all trial stats will be returned directly.
If 'mean', fields in trial stats will be averaged given the same trial config.
include_intermediates : boolean
If true, intermediate results will be returned.
Returns
-------
......@@ -60,5 +62,13 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
query = query.where(functools.reduce(lambda a, b: a & b, conditions))
if reduction is not None:
query = query.group_by(NdsTrialStats.config)
for k in query:
yield model_to_dict(k)
for trial in query:
if include_intermediates:
data = model_to_dict(trial)
# exclude 'trial' from intermediates as it is already available in data
data['intermediates'] = [
{k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates
]
yield data
else:
yield model_to_dict(trial)
import functools
import json
json_dumps = functools.partial(json.dumps, sort_keys=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment