"git@developer.sourcefind.cn:OpenDAS/nni.git" did not exist on "1d3f4fd9138530510955f56a0dd4a79680aee5b3"
Unverified Commit 717877d0 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Add API to query intermediate results in NAS benchmark (#2728)

parent c1844898
...@@ -34,19 +34,22 @@ ...@@ -34,19 +34,22 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Use the following architecture as an example:<br>\n", "Use the following architecture as an example:\n",
"\n",
"![nas-101](../../img/nas-bench-101-example.png)" "![nas-101](../../img/nas-bench-101-example.png)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 2,
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
"name": "stdout", "name": "stdout",
"text": "{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 10,\n 'parameters': 8.55553,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 106147.67578125,\n 'valid_acc': 92.41786599159241}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 11,\n 'parameters': 8.55553,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 106095.05859375,\n 'valid_acc': 92.45793223381042}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 12,\n 'parameters': 8.55553,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 106138.55712890625,\n 'valid_acc': 93.04887652397156}\n" "text": "{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 10,\n 'intermediates': [{'current_epoch': 54,\n 'id': 19,\n 'test_acc': 77.40384340286255,\n 'train_acc': 82.82251358032227,\n 'training_time': 883.4580078125,\n 'valid_acc': 77.76442170143127},\n {'current_epoch': 108,\n 'id': 20,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 1769.1279296875,\n 'valid_acc': 92.41786599159241}],\n 'parameters': 8.55553,\n 'test_acc': 92.11738705635071,\n 'train_acc': 100.0,\n 'training_time': 106147.67578125,\n 'valid_acc': 92.41786599159241}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 11,\n 'intermediates': [{'current_epoch': 54,\n 'id': 21,\n 'test_acc': 82.04126358032227,\n 'train_acc': 87.96073794364929,\n 'training_time': 883.6810302734375,\n 'valid_acc': 82.91265964508057},\n {'current_epoch': 108,\n 'id': 22,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 1768.2509765625,\n 'valid_acc': 92.45793223381042}],\n 'parameters': 8.55553,\n 'test_acc': 91.90705418586731,\n 'train_acc': 100.0,\n 'training_time': 106095.05859375,\n 'valid_acc': 92.45793223381042}\n{'config': {'arch': {'input1': [0],\n 'input2': [1],\n 'input3': [2],\n 'input4': [0],\n 'input5': [0, 3, 4],\n 'input6': [2, 5],\n 'op1': 'conv3x3-bn-relu',\n 'op2': 'maxpool3x3',\n 'op3': 'conv3x3-bn-relu',\n 'op4': 'conv3x3-bn-relu',\n 'op5': 'conv1x1-bn-relu'},\n 'hash': '00005c142e6f48ac74fdcf73e3439874',\n 'id': 4,\n 'num_epochs': 108,\n 'num_vertices': 7},\n 'id': 12,\n 'intermediates': [{'current_epoch': 54,\n 'id': 23,\n 'test_acc': 80.58894276618958,\n 'train_acc': 86.34815812110901,\n 'training_time': 883.4569702148438,\n 'valid_acc': 81.1598539352417},\n {'current_epoch': 108,\n 'id': 24,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 1768.9759521484375,\n 'valid_acc': 93.04887652397156}],\n 'parameters': 8.55553,\n 'test_acc': 92.15745329856873,\n 'train_acc': 100.0,\n 'training_time': 106138.55712890625,\n 'valid_acc': 93.04887652397156}\n"
} }
], ],
"source": [ "source": [
...@@ -63,7 +66,7 @@ ...@@ -63,7 +66,7 @@
" 'input5': [0, 3, 4],\n", " 'input5': [0, 3, 4],\n",
" 'input6': [2, 5]\n", " 'input6': [2, 5]\n",
"}\n", "}\n",
"for t in query_nb101_trial_stats(arch, 108):\n", "for t in query_nb101_trial_stats(arch, 108, include_intermediates=True):\n",
" pprint.pprint(t)" " pprint.pprint(t)"
] ]
}, },
...@@ -85,14 +88,17 @@ ...@@ -85,14 +88,17 @@
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [ "source": [
"Use the following architecture as an example:<br>\n", "Use the following architecture as an example:\n",
"\n",
"![nas-201](../../img/nas-bench-201-example.png)" "![nas-201](../../img/nas-bench-201-example.png)"
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": 3,
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
...@@ -113,6 +119,32 @@ ...@@ -113,6 +119,32 @@
" pprint.pprint(t)" " pprint.pprint(t)"
] ]
}, },
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Intermediate results are also available."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "{'id': 4, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 12, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 12\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n{'id': 8, 'arch': {'0_1': 'avg_pool_3x3', '0_2': 'conv_1x1', '0_3': 'conv_1x1', '1_2': 'skip_connect', '1_3': 'skip_connect', '2_3': 'skip_connect'}, 'num_epochs': 200, 'num_channels': 16, 'num_cells': 5, 'dataset': 'imagenet16-120'}\nIntermediates: 200\n"
}
],
"source": [
"for t in query_nb201_trial_stats(arch, None, 'imagenet16-120', include_intermediates=True):\n",
" print(t['config'])\n",
" print('Intermediates:', len(t['intermediates']))"
]
},
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
...@@ -132,8 +164,10 @@ ...@@ -132,8 +164,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 4, "execution_count": 5,
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
...@@ -156,8 +190,35 @@ ...@@ -156,8 +190,35 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 6,
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": "[{'current_epoch': 1,\n 'id': 4494501,\n 'test_acc': 41.76,\n 'train_acc': 30.421000000000006,\n 'train_loss': 1.793},\n {'current_epoch': 2,\n 'id': 4494502,\n 'test_acc': 54.66,\n 'train_acc': 47.24,\n 'train_loss': 1.415},\n {'current_epoch': 3,\n 'id': 4494503,\n 'test_acc': 59.97,\n 'train_acc': 56.983,\n 'train_loss': 1.179},\n {'current_epoch': 4,\n 'id': 4494504,\n 'test_acc': 62.91,\n 'train_acc': 61.955,\n 'train_loss': 1.048},\n {'current_epoch': 5,\n 'id': 4494505,\n 'test_acc': 66.16,\n 'train_acc': 64.493,\n 'train_loss': 0.983},\n {'current_epoch': 6,\n 'id': 4494506,\n 'test_acc': 66.5,\n 'train_acc': 66.274,\n 'train_loss': 0.937},\n {'current_epoch': 7,\n 'id': 4494507,\n 'test_acc': 67.55,\n 'train_acc': 67.426,\n 'train_loss': 0.907},\n {'current_epoch': 8,\n 'id': 4494508,\n 'test_acc': 69.45,\n 'train_acc': 68.45400000000001,\n 'train_loss': 0.878},\n {'current_epoch': 9,\n 'id': 4494509,\n 'test_acc': 70.14,\n 'train_acc': 69.295,\n 'train_loss': 0.857},\n {'current_epoch': 10,\n 'id': 4494510,\n 'test_acc': 69.47,\n 'train_acc': 70.304,\n 'train_loss': 0.832}]\n"
}
],
"source": [
"model_spec = {\n",
" 'bot_muls': [0.0, 0.25, 0.25, 0.25],\n",
" 'ds': [1, 16, 1, 4],\n",
" 'num_gs': [1, 2, 1, 2],\n",
" 'ss': [1, 1, 2, 2],\n",
" 'ws': [16, 64, 128, 16]\n",
"}\n",
"for t in query_nds_trial_stats('residual_bottleneck', None, None, model_spec, None, 'cifar10', include_intermediates=True):\n",
" pprint.pprint(t['intermediates'][:10])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"tags": []
},
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
...@@ -173,8 +234,10 @@ ...@@ -173,8 +234,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 6, "execution_count": 8,
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
...@@ -189,8 +252,10 @@ ...@@ -189,8 +252,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 7, "execution_count": 9,
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
...@@ -254,8 +319,10 @@ ...@@ -254,8 +319,10 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 8, "execution_count": 10,
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
...@@ -270,13 +337,15 @@ ...@@ -270,13 +337,15 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 9, "execution_count": 11,
"metadata": {}, "metadata": {
"tags": []
},
"outputs": [ "outputs": [
{ {
"output_type": "stream", "output_type": "stream",
"name": "stdout", "name": "stdout",
"text": "Elapsed time: 1.9107539653778076 seconds\n" "text": "Elapsed time: 2.2023813724517822 seconds\n"
} }
], ],
"source": [ "source": [
......
...@@ -4,6 +4,7 @@ from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model ...@@ -4,6 +4,7 @@ from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
from nni.nas.benchmarks.constants import DATABASE_DIR from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps
db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench101.db'), autoconnect=True) db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench101.db'), autoconnect=True)
...@@ -28,7 +29,7 @@ class Nb101TrialConfig(Model): ...@@ -28,7 +29,7 @@ class Nb101TrialConfig(Model):
Number of epochs planned for this trial. Should be one of 4, 12, 36, 108 in default setup. Number of epochs planned for this trial. Should be one of 4, 12, 36, 108 in default setup.
""" """
arch = JSONField(index=True) arch = JSONField(json_dumps=json_dumps, index=True)
num_vertices = IntegerField(index=True) num_vertices = IntegerField(index=True)
hash = CharField(max_length=64, index=True) hash = CharField(max_length=64, index=True)
num_epochs = IntegerField(index=True) num_epochs = IntegerField(index=True)
......
...@@ -6,7 +6,7 @@ from .model import Nb101TrialStats, Nb101TrialConfig ...@@ -6,7 +6,7 @@ from .model import Nb101TrialStats, Nb101TrialConfig
from .graph_util import hash_module, infer_num_vertices from .graph_util import hash_module, infer_num_vertices
def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None): def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None, include_intermediates=False):
""" """
Query trial stats of NAS-Bench-101 given conditions. Query trial stats of NAS-Bench-101 given conditions.
...@@ -24,6 +24,8 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None): ...@@ -24,6 +24,8 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None):
reduction : str or None reduction : str or None
If 'none' or None, all trial stats will be returned directly. If 'none' or None, all trial stats will be returned directly.
If 'mean', fields in trial stats will be averaged given the same trial config. If 'mean', fields in trial stats will be averaged given the same trial config.
include_intermediates : boolean
If true, intermediate results will be returned.
Returns Returns
------- -------
...@@ -56,5 +58,13 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None): ...@@ -56,5 +58,13 @@ def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None):
query = query.where(functools.reduce(lambda a, b: a & b, conditions)) query = query.where(functools.reduce(lambda a, b: a & b, conditions))
if reduction is not None: if reduction is not None:
query = query.group_by(Nb101TrialStats.config) query = query.group_by(Nb101TrialStats.config)
for k in query: for trial in query:
yield model_to_dict(k) if include_intermediates:
data = model_to_dict(trial)
# exclude 'trial' from intermediates as it is already available in data
data['intermediates'] = [
{k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates
]
yield data
else:
yield model_to_dict(trial)
...@@ -4,6 +4,7 @@ from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model ...@@ -4,6 +4,7 @@ from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
from nni.nas.benchmarks.constants import DATABASE_DIR from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps
db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench201.db'), autoconnect=True) db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nasbench201.db'), autoconnect=True)
...@@ -35,7 +36,7 @@ class Nb201TrialConfig(Model): ...@@ -35,7 +36,7 @@ class Nb201TrialConfig(Model):
for training, 6k images from validation set for validation and the other 6k for testing). for training, 6k images from validation set for validation and the other 6k for testing).
""" """
arch = JSONField(index=True) arch = JSONField(json_dumps=json_dumps, index=True)
num_epochs = IntegerField(index=True) num_epochs = IntegerField(index=True)
num_channels = IntegerField() num_channels = IntegerField()
num_cells = IntegerField() num_cells = IntegerField()
......
...@@ -5,7 +5,7 @@ from playhouse.shortcuts import model_to_dict ...@@ -5,7 +5,7 @@ from playhouse.shortcuts import model_to_dict
from .model import Nb201TrialStats, Nb201TrialConfig from .model import Nb201TrialStats, Nb201TrialConfig
def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None): def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None, include_intermediates=False):
""" """
Query trial stats of NAS-Bench-201 given conditions. Query trial stats of NAS-Bench-201 given conditions.
...@@ -23,6 +23,8 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None): ...@@ -23,6 +23,8 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None):
reduction : str or None reduction : str or None
If 'none' or None, all trial stats will be returned directly. If 'none' or None, all trial stats will be returned directly.
If 'mean', fields in trial stats will be averaged given the same trial config. If 'mean', fields in trial stats will be averaged given the same trial config.
include_intermediates : boolean
If true, intermediate results will be returned.
Returns Returns
------- -------
...@@ -53,5 +55,13 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None): ...@@ -53,5 +55,13 @@ def query_nb201_trial_stats(arch, num_epochs, dataset, reduction=None):
query = query.where(functools.reduce(lambda a, b: a & b, conditions)) query = query.where(functools.reduce(lambda a, b: a & b, conditions))
if reduction is not None: if reduction is not None:
query = query.group_by(Nb201TrialStats.config) query = query.group_by(Nb201TrialStats.config)
for k in query: for trial in query:
yield model_to_dict(k) if include_intermediates:
data = model_to_dict(trial)
# exclude 'trial' from intermediates as it is already available in data
data['intermediates'] = [
{k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates
]
yield data
else:
yield model_to_dict(trial)
...@@ -4,6 +4,7 @@ from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model ...@@ -4,6 +4,7 @@ from peewee import CharField, FloatField, ForeignKeyField, IntegerField, Model
from playhouse.sqlite_ext import JSONField, SqliteExtDatabase from playhouse.sqlite_ext import JSONField, SqliteExtDatabase
from nni.nas.benchmarks.constants import DATABASE_DIR from nni.nas.benchmarks.constants import DATABASE_DIR
from nni.nas.benchmarks.utils import json_dumps
db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nds.db'), autoconnect=True) db = SqliteExtDatabase(os.path.join(DATABASE_DIR, 'nds.db'), autoconnect=True)
...@@ -52,8 +53,8 @@ class NdsTrialConfig(Model): ...@@ -52,8 +53,8 @@ class NdsTrialConfig(Model):
'residual_basic', 'residual_basic',
'vanilla', 'vanilla',
]) ])
model_spec = JSONField(index=True) model_spec = JSONField(json_dumps=json_dumps, index=True)
cell_spec = JSONField(index=True, null=True) cell_spec = JSONField(json_dumps=json_dumps, index=True, null=True)
dataset = CharField(max_length=15, index=True, choices=['cifar10', 'imagenet']) dataset = CharField(max_length=15, index=True, choices=['cifar10', 'imagenet'])
generator = CharField(max_length=15, index=True, choices=[ generator = CharField(max_length=15, index=True, choices=[
'random', 'random',
......
...@@ -6,7 +6,7 @@ from .model import NdsTrialStats, NdsTrialConfig ...@@ -6,7 +6,7 @@ from .model import NdsTrialStats, NdsTrialConfig
def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_spec, dataset, def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_spec, dataset,
num_epochs=None, reduction=None): num_epochs=None, reduction=None, include_intermediates=False):
""" """
Query trial stats of NDS given conditions. Query trial stats of NDS given conditions.
...@@ -32,6 +32,8 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp ...@@ -32,6 +32,8 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
reduction : str or None reduction : str or None
If 'none' or None, all trial stats will be returned directly. If 'none' or None, all trial stats will be returned directly.
If 'mean', fields in trial stats will be averaged given the same trial config. If 'mean', fields in trial stats will be averaged given the same trial config.
include_intermediates : boolean
If true, intermediate results will be returned.
Returns Returns
------- -------
...@@ -60,5 +62,13 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp ...@@ -60,5 +62,13 @@ def query_nds_trial_stats(model_family, proposer, generator, model_spec, cell_sp
query = query.where(functools.reduce(lambda a, b: a & b, conditions)) query = query.where(functools.reduce(lambda a, b: a & b, conditions))
if reduction is not None: if reduction is not None:
query = query.group_by(NdsTrialStats.config) query = query.group_by(NdsTrialStats.config)
for k in query: for trial in query:
yield model_to_dict(k) if include_intermediates:
data = model_to_dict(trial)
# exclude 'trial' from intermediates as it is already available in data
data['intermediates'] = [
{k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates
]
yield data
else:
yield model_to_dict(trial)
import functools
import json
json_dumps = functools.partial(json.dumps, sort_keys=True)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment