from genotype.protoc_pb2 import Genotype, Cell
from google.protobuf import text_format
from collections import namedtuple

__all__ = [
    'convert_legacy_format_to_protobuf',
    'load_genotype_from_file',
    'save_genotype_to_file',
]

# Old tuple-based genotype. `normal` / `reduce` are lists of
# (op_name, input_index) pairs, consumed two at a time per intermediate node
# (see the `j // 2 + 2` destination below); `*_concat` list the node indices
# whose outputs form the cell output.
LegacyGenotype = namedtuple('LegacyGenotype', 'normal normal_concat reduce reduce_concat')

# Operation vocabulary. An op's position in this list is the integer stored
# in the protobuf op `type` field, so the order must not change.
PRIMITIVES = [
    'none',
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
]


def _primitive_index(op_name):
    """Return the index of *op_name* in PRIMITIVES, with a readable error if absent."""
    try:
        return PRIMITIVES.index(op_name)
    except ValueError:
        raise ValueError('unknown operation %r; expected one of %s' % (op_name, PRIMITIVES))


def convert_legacy_format_to_protobuf(legacy_genotype, init_channel=48, layers=14):
    """
    Convert old format NASNET genotype to a protobuf format object

    Cells at positions ``layers // 3`` and ``2 * layers // 3`` are reduction
    cells, and the channel count doubles at each of them. The cell at
    ``2 * layers // 3`` is also flagged ``auxiliary``.

    Args:
        legacy_genotype(LegacyGenotype): a named tuple containing normal,
            normal_concat, reduce, and reduce_concat
        init_channel: init channel count of the network, default 48
        layers: number of cells of the network, default 14

    Returns:
        A protobuf genotype object defining the architecture

    Raises:
        ValueError: if an operation name is not listed in PRIMITIVES
    """
    reduction_layers = {layers // 3, 2 * layers // 3}
    auxiliary_layer = 2 * layers // 3

    g = Genotype()
    g.init_channel = init_channel
    C = init_channel
    for i in range(layers):
        reduction = i in reduction_layers
        if reduction:
            C *= 2
            edges, concat = legacy_genotype.reduce, legacy_genotype.reduce_concat
        else:
            edges, concat = legacy_genotype.normal, legacy_genotype.normal_concat

        # Two incoming edges per intermediate node; nodes 0 and 1 are the
        # cell inputs, so the j-th edge feeds node j // 2 + 2.
        ops = [
            {'frm': indice, 'to': j // 2 + 2, 'type': _primitive_index(op_name)}
            for j, (op_name, indice) in enumerate(edges)
        ]
        g.cell.append(Cell(
            id=i,
            type=reduction,
            channel=C,
            num_steps=len(edges) // 2,
            op=ops,
            # Honour the legacy genotype's concat spec instead of a fixed
            # [2, 3, 4, 5]; list() materialises range objects.
            concat=list(concat),
            auxiliary=(i == auxiliary_layer),
        ))
    return g


def load_genotype_from_file(file):
    """
    Load Genotype object from text file

    Args:
        file: text file defining the architecture following protobuf format

    Returns:
        a Genotype object
    """
    g = Genotype()
    with open(file, 'r') as f:
        text_format.Parse(f.read(), g)
    return g


def save_genotype_to_file(file, genotype):
    """
    Save Genotype object to file, used for submission

    Args:
        file: target text file to store the genotype
        genotype: a Genotype object

    Returns:
        None
    """
    with open(file, 'w') as f:
        f.write(text_format.MessageToString(genotype))


def main():
    """Convert the hard-coded PDARTS legacy genotype and write it to pdarts.txt."""
    PDARTS = LegacyGenotype(
        normal=[
            ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
            ('sep_conv_3x3', 1), ('sep_conv_3x3', 0),
            ('sep_conv_5x5', 1), ('dil_conv_5x5', 3),
            ('sep_conv_3x3', 1), ('sep_conv_3x3', 3),
        ],
        normal_concat=range(2, 6),
        reduce=[
            ('skip_connect', 0), ('skip_connect', 1),
            ('max_pool_3x3', 0), ('skip_connect', 2),
            ('max_pool_3x3', 0), ('skip_connect', 2),
            ('max_pool_3x3', 0), ('skip_connect', 2),
        ],
        # NOTE(review): range(3, 6) leaves node 2 out of the reduction-cell
        # output; confirm this is intended, since it is now written to the file.
        reduce_concat=range(3, 6),
    )
    new_PDARTS = convert_legacy_format_to_protobuf(PDARTS)
    save_genotype_to_file('pdarts.txt', new_PDARTS)