from .model import NetworkImageNet
from .genotypes import PDARTS, DARTS_V2, NASNet, AmoebaNet  # NASNet/AmoebaNet kept importable from this package
from .protoc.utils import load_genotype_from_file, convert_legacy_format_to_protobuf


def _from_legacy_genotype(genotype, init_channel, layers):
    """Convert a legacy-format genotype to protobuf and build a NetworkImageNet.

    Shared by the named-architecture factories so they construct networks
    in exactly the same way.

    Args:
        genotype: a genotype in the legacy format (e.g. ``PDARTS``, ``DARTS_V2``).
        init_channel: forwarded to ``convert_legacy_format_to_protobuf``.
        layers: forwarded to ``convert_legacy_format_to_protobuf``.

    Returns:
        A ``NetworkImageNet`` built from the converted genotype.
    """
    geno = convert_legacy_format_to_protobuf(
        genotype, init_channel=init_channel, layers=layers
    )
    return NetworkImageNet(geno)


def pdarts(init_channel=48, layers=14):
    """Build a NetworkImageNet from the built-in PDARTS genotype.

    Args:
        init_channel: forwarded to the legacy->protobuf converter
            (presumably the initial channel width -- confirm in protoc.utils).
        layers: forwarded to the legacy->protobuf converter
            (presumably the number of cells -- confirm in protoc.utils).

    Returns:
        A ``NetworkImageNet`` instance.
    """
    return _from_legacy_genotype(PDARTS, init_channel, layers)


def darts(init_channel=48, layers=14):
    """Build a NetworkImageNet from the built-in DARTS_V2 genotype.

    Args:
        init_channel: forwarded to the legacy->protobuf converter
            (presumably the initial channel width -- confirm in protoc.utils).
        layers: forwarded to the legacy->protobuf converter
            (presumably the number of cells -- confirm in protoc.utils).

    Returns:
        A ``NetworkImageNet`` instance.
    """
    return _from_legacy_genotype(DARTS_V2, init_channel, layers)


def network(genotype_file):
    """Build a NetworkImageNet from a genotype stored in a file.

    Prints a progress message, then loads the genotype with
    ``load_genotype_from_file`` (file format is defined there -- presumably
    already protobuf, since no legacy conversion is applied).

    Args:
        genotype_file: path of the genotype file to load.

    Returns:
        A ``NetworkImageNet`` instance.
    """
    print('=> load genotype from', genotype_file)
    geno = load_genotype_from_file(genotype_file)
    return NetworkImageNet(geno)