# query.py -- query helpers for NAS-Bench-101 trial statistics.
import functools

from peewee import fn
from playhouse.shortcuts import model_to_dict
from .model import Nb101TrialStats, Nb101TrialConfig
from .graph_util import hash_module, infer_num_vertices


9
def query_nb101_trial_stats(arch, num_epochs, isomorphism=True, reduction=None, include_intermediates=False):
    """
    Query trial stats of NAS-Bench-101 given conditions.

    Parameters
    ----------
    arch : dict or None
        If a dict, it is in the format that is described in
        :class:`nni.nas.benchmark.nasbench101.Nb101TrialConfig`. Only trial stats
        matched will be returned. If None, all architectures in the database will be matched.
    num_epochs : int or None
        If int, matching results will be returned. Otherwise a wildcard.
    isomorphism : boolean
        Whether to match essentially-same architecture, i.e., architecture with the
        same graph-invariant hash value.
    reduction : str or None
        If 'none' or None, all trial stats will be returned directly.
        If 'mean', fields in trial stats will be averaged given the same trial config.
    include_intermediates : boolean
        If true, intermediate results will be returned.

    Returns
    -------
    generator of dict
        A generator of :class:`nni.nas.benchmark.nasbench101.Nb101TrialStats` objects,
        where each of them has been converted into a dict.

    Raises
    ------
    ValueError
        If ``reduction`` is not one of 'none', None or 'mean'. Because this is a
        generator function, the error surfaces when the generator is first
        advanced, not at call time.
    """
    # 'none' is accepted as a spelling of "no reduction".
    if reduction == 'none':
        reduction = None

    if reduction is None:
        fields = [Nb101TrialStats]
    elif reduction == 'mean':
        # Average every stats column except the primary key and the foreign key
        # to the config, which is the grouping key below.
        fields = [
            fn.AVG(getattr(Nb101TrialStats, field_name)).alias(field_name)
            for field_name in Nb101TrialStats._meta.sorted_field_names
            if field_name not in ('id', 'config')
        ]
    else:
        # Wrap in a 1-tuple so that a tuple-valued ``reduction`` is formatted
        # as a single value instead of triggering a TypeError in ``%``.
        raise ValueError('Unsupported reduction: \'%s\'' % (reduction,))

    query = Nb101TrialStats.select(*fields, Nb101TrialConfig).join(Nb101TrialConfig)

    conditions = []
    if arch is not None:
        if isomorphism:
            # Match on the graph hash so that equivalent architectures are found.
            num_vertices = infer_num_vertices(arch)
            conditions.append(Nb101TrialConfig.hash == hash_module(arch, num_vertices))
        else:
            conditions.append(Nb101TrialConfig.arch == arch)
    if num_epochs is not None:
        conditions.append(Nb101TrialConfig.num_epochs == num_epochs)
    if conditions:
        # AND all collected conditions together into one WHERE clause.
        query = query.where(functools.reduce(lambda a, b: a & b, conditions))
    if reduction is not None:
        query = query.group_by(Nb101TrialStats.config)

    for trial in query:
        data = model_to_dict(trial)
        if include_intermediates:
            # NOTE(review): with reduction='mean' the selected rows are aggregates
            # that do not carry an ``id``; confirm ``trial.intermediates`` is
            # meaningful in that combination.
            # exclude 'trial' from intermediates as it is already available in data
            data['intermediates'] = [
                {k: v for k, v in model_to_dict(t).items() if k != 'trial'} for t in trial.intermediates
            ]
        yield data