# db_gen.py: build the NNI 'nasbench101' benchmark database from a NAS-Bench-101 tfrecord file.
import argparse

from tqdm import tqdm
from nasbench import api  # pylint: disable=import-error

from nni.nas.benchmarks.utils import load_benchmark
from .model import Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats
from .graph_util import nasbench_format_to_architecture_repr, hash_module


# Epoch budgets iterated for every architecture; one Nb101TrialConfig is
# created per (architecture, budget) pair.
_EPOCH_BUDGETS = (4, 12, 36, 108)
# Number of repeated runs read from ``metrics[epochs]`` for each budget.
_NUM_SEEDS = 3


def _dump_architecture(nasbench, hashval):
    """Insert all rows for the architecture identified by ``hashval``.

    For each epoch budget this creates one ``Nb101TrialConfig``, one
    ``Nb101TrialStats`` per seed, and a 'halfway' and a 'final'
    ``Nb101IntermediateStats`` row per trial.

    Raises:
        ValueError: if re-hashing the converted architecture does not reproduce
            ``hashval``.
    """
    metadata, metrics = nasbench.get_metrics_from_hash(hashval)
    num_vertices, architecture = nasbench_format_to_architecture_repr(
        metadata['module_adjacency'], metadata['module_operations'])
    # Sanity check that the conversion is lossless w.r.t. the lookup hash.
    # An explicit raise (instead of ``assert``) keeps it active under ``python -O``.
    if hash_module(architecture, num_vertices) != hashval:
        raise ValueError('Hash mismatch after converting architecture {}'.format(hashval))

    # Same for every budget and seed of this architecture; stored in millions.
    parameters = metadata['trainable_parameters'] / 1e6

    for epochs in _EPOCH_BUDGETS:
        trial_config = Nb101TrialConfig.create(
            arch=architecture,
            num_vertices=num_vertices,
            hash=hashval,
            num_epochs=epochs
        )

        for seed in range(_NUM_SEEDS):
            cur = metrics[epochs][seed]
            trial = Nb101TrialStats.create(
                config=trial_config,
                # Accuracies are multiplied by 100 (presumably fraction -> percent).
                train_acc=cur['final_train_accuracy'] * 100,
                valid_acc=cur['final_validation_accuracy'] * 100,
                test_acc=cur['final_test_accuracy'] * 100,
                parameters=parameters,
                # NOTE(review): 'final_training_time' is scaled by 60 here, but the
                # 'final' Nb101IntermediateStats row below stores the same value
                # unscaled. One of the two is likely wrong; confirm the unit that
                # Nb101TrialStats.training_time expects in model.py.
                training_time=cur['final_training_time'] * 60
            )
            # Two checkpoints per trial: half the budget and the full budget.
            for t, current_epoch in (('halfway', epochs // 2), ('final', epochs)):
                Nb101IntermediateStats.create(
                    trial=trial,
                    current_epoch=current_epoch,
                    training_time=cur[t + '_training_time'],
                    train_acc=cur[t + '_train_accuracy'] * 100,
                    valid_acc=cur[t + '_validation_accuracy'] * 100,
                    test_acc=cur[t + '_test_accuracy'] * 100
                )


def main():
    """Convert a NAS-Bench-101 tfrecord into NNI's ``nasbench101`` benchmark database.

    The input path is taken from the command line. The tables are created
    if needed, and every architecture hash in the file is dumped inside the
    ``with db:`` block.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('input_file',
                        help='Path to the file to be converted, e.g., nasbench_full.tfrecord')
    args = parser.parse_args()
    nasbench = api.NASBench(args.input_file)

    db = load_benchmark('nasbench101')
    with db:
        db.create_tables([Nb101TrialConfig, Nb101TrialStats, Nb101IntermediateStats])
        for hashval in tqdm(nasbench.hash_iterator(), desc='Dumping data into database'):
            _dump_architecture(nasbench, hashval)


if __name__ == '__main__':
    # Script entry point: main() returns None, so the process exits with status 0.
    raise SystemExit(main())