utils.py 2.92 KB
Newer Older
huchen's avatar
huchen committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from genotype.protoc_pb2 import Genotype, Cell
from google.protobuf import text_format
from collections import namedtuple


# Public API of this module. LegacyGenotype is exported because
# convert_legacy_format_to_protobuf expects an instance of it as input.
__all__ = [
    'LegacyGenotype',
    'convert_legacy_format_to_protobuf',
    'load_genotype_from_file',
    'save_genotype_to_file',
]


# Legacy (pre-protobuf) genotype record: `normal` / `reduce` hold
# (op_name, input_index) pairs; `normal_concat` / `reduce_concat` hold the
# node indices that are concatenated to form the cell output.
LegacyGenotype = namedtuple(
    'LegacyGenotype',
    ('normal', 'normal_concat', 'reduce', 'reduce_concat'),
)

# Candidate operation names. An operation's position in this list is the
# integer used as its 'type' when converting to the protobuf format, so the
# order must not change.
PRIMITIVES = [
    'none',
    # pooling / identity
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    # separable convolutions
    'sep_conv_3x3',
    'sep_conv_5x5',
    # dilated convolutions
    'dil_conv_3x3',
    'dil_conv_5x5',
]


def convert_legacy_format_to_protobuf(legacy_genotype, init_channel=48, layers=14):
    """
    Convert old format NASNET genotype to a protobuf format object

    Cells at indices ``layers // 3`` and ``2 * layers // 3`` are reduction
    cells; each one doubles the running channel count. The cell at index
    ``2 * layers // 3`` is also flagged as ``auxiliary``.

    Args:
        legacy_genotype(LegacyGenotype): a named tuple containing normal, normal_concat, reduce, and reduce_concat.
            ``normal`` / ``reduce`` are sequences of ``(op_name, input_index)`` pairs, two per intermediate node;
            ``normal_concat`` / ``reduce_concat`` are the node indices concatenated to form the cell output.
        init_channel: init channel count of the network, default 48
        layers: number of cells of the network, default 14

    Returns:
        A protobuf genotype object defining the architecture

    Raises:
        ValueError: if an op name in the genotype is not listed in ``PRIMITIVES``.
    """
    g = Genotype()
    g.init_channel = init_channel

    reduction_layers = (layers // 3, 2 * layers // 3)
    auxiliary_layer = 2 * layers // 3

    C = init_channel
    for i in range(layers):
        reduction = i in reduction_layers
        if reduction:
            C *= 2
        auxiliary = i == auxiliary_layer

        # Honour the genotype's own concat lists instead of assuming [2, 3, 4, 5].
        if reduction:
            edges, concat = legacy_genotype.reduce, legacy_genotype.reduce_concat
        else:
            edges, concat = legacy_genotype.normal, legacy_genotype.normal_concat

        # Edges come in pairs per intermediate node: edge j feeds node j // 2 + 2
        # (the +2 offset skips node ids 0 and 1, presumably the two cell inputs).
        ops = [
            {'frm': indice, 'to': j // 2 + 2, 'type': PRIMITIVES.index(op_name)}
            for j, (op_name, indice) in enumerate(edges)
        ]

        g.cell.append(Cell(
            id=i,
            type=reduction,
            channel=C,
            # Two edges per intermediate node, so the step count follows the op list
            # (4 for the usual 8-edge genotype).
            num_steps=len(ops) // 2,
            op=ops,
            concat=list(concat),
            auxiliary=auxiliary,
        ))

    return g


def load_genotype_from_file(file):
    """
    Load Genotype object from text file
    Args:
        file: path of a text file defining the architecture in protobuf text format

    Returns:
        a Genotype object

    Raises:
        OSError: if the file cannot be opened or read.
        google.protobuf.text_format.ParseError: if the content is not a valid Genotype text message.
    """
    # Read inside a context manager so the file handle is always closed;
    # the original left it open until garbage collection.
    with open(file, 'r') as f:
        content = f.read()

    g = Genotype()
    text_format.Parse(content, g)
    return g


def save_genotype_to_file(file, genotype):
    """
    Write a Genotype message to disk in protobuf text format (used for submission).

    Args:
        file: path of the target text file; it is created or overwritten
        genotype: the Genotype message to serialize

    Returns:
        None
    """
    with open(file, 'w') as out_stream:
        serialized = text_format.MessageToString(genotype)
        out_stream.write(serialized)

def main(output_file='pdarts.txt'):
    """
    Convert the hard-coded PDARTS legacy genotype and save it in protobuf text format.

    Args:
        output_file: destination path for the serialized genotype, default 'pdarts.txt'

    Returns:
        None
    """
    # NOTE(review): reduce_concat is range(3, 6) while normal_concat is range(2, 6);
    # confirm this asymmetry is intended.
    PDARTS = LegacyGenotype(
        normal=[
            ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
            ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
            ('sep_conv_5x5', 1), ('dil_conv_5x5', 3),
            ('sep_conv_3x3', 1), ('sep_conv_3x3', 3),
        ],
        normal_concat=range(2, 6),
        reduce=[
            ('skip_connect', 0), ('skip_connect', 1),
            ('max_pool_3x3', 0), ('skip_connect', 2),
            ('max_pool_3x3', 0), ('skip_connect', 2),
            ('max_pool_3x3', 0), ('skip_connect', 2),
        ],
        reduce_concat=range(3, 6),
    )
    new_PDARTS = convert_legacy_format_to_protobuf(PDARTS)
    save_genotype_to_file(output_file, new_PDARTS)


# Without this guard main() was never executed when the file was run as a script.
if __name__ == '__main__':
    main()