from .model import NetworkImageNet
from .genotypes import PDARTS, DARTS_V2, NASNet, AmoebaNet
from .protoc.utils import load_genotype_from_file, convert_legacy_format_to_protobuf


def pdarts(init_channel=48, layers=14):
    """Build a ``NetworkImageNet`` from the legacy ``PDARTS`` genotype.

    The ``PDARTS`` genotype is passed through
    ``convert_legacy_format_to_protobuf`` together with the given settings,
    and the converted result is handed to ``NetworkImageNet``.

    Args:
        init_channel: Forwarded to the converter as ``init_channel``
            (presumably the initial channel count -- confirm in converter).
        layers: Forwarded to the converter as ``layers``
            (presumably the number of layers -- confirm in converter).

    Returns:
        The ``NetworkImageNet`` instance constructed from the converted
        genotype.
    """
    return NetworkImageNet(
        convert_legacy_format_to_protobuf(
            PDARTS,
            init_channel=init_channel,
            layers=layers,
        )
    )


def darts(init_channel=48, layers=14):
    """Build a ``NetworkImageNet`` from the legacy ``DARTS_V2`` genotype.

    The ``DARTS_V2`` genotype is passed through
    ``convert_legacy_format_to_protobuf`` together with the given settings,
    and the converted result is handed to ``NetworkImageNet``.

    Args:
        init_channel: Forwarded to the converter as ``init_channel``
            (presumably the initial channel count -- confirm in converter).
        layers: Forwarded to the converter as ``layers``
            (presumably the number of layers -- confirm in converter).

    Returns:
        The ``NetworkImageNet`` instance constructed from the converted
        genotype.
    """
    return NetworkImageNet(
        convert_legacy_format_to_protobuf(
            DARTS_V2,
            init_channel=init_channel,
            layers=layers,
        )
    )


def network(genotype_file):
    """Build a ``NetworkImageNet`` from a genotype stored in a file.

    Prints a short notice naming the file, loads the genotype with
    ``load_genotype_from_file`` and passes it to ``NetworkImageNet``.

    Args:
        genotype_file: Value handed to ``load_genotype_from_file``
            (presumably a path to a serialized genotype -- confirm in loader).

    Returns:
        The ``NetworkImageNet`` instance constructed from the loaded genotype.
    """
    # Announce the source before loading so the notice appears even if
    # loading fails.
    print('=> load genotype from', genotype_file)
    loaded_genotype = load_genotype_from_file(genotype_file)
    model = NetworkImageNet(loaded_genotype)
    return model
