# 简介 本用例用于神经网络架构搜索(NAS)领域darts算法在ROCm平台PyTorch框架下的测试,包括架构搜索和架构评估两部分内容,已在rocm2.9 pytorch1.3.0版本下进行验证,测试流程如下。 # 测试流程 ## 准备数据 以[cifar-10](http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz)为例进行测试说明,也可以自行下载PTB和ImageNet数据集。 ## 预训练模型 可以选择在已有的预训练模型上进行训练,下载地址如下: [CIFAR-10](https://drive.google.com/file/d/1Y13i4zKGKgjtWBdC0HWLavjO7wvEiGOc/view?usp=sharing) [PTB](https://drive.google.com/file/d/1Mt_o6fZOlG-VDF3Q5ModgnAJ9W6f_av2/view?usp=sharing) [ImageNet](https://drive.google.com/file/d/1AKr6Y_PoYj7j0Upggyzc26W0RVdg4CVX/view?usp=sharing) ## 运行指令 ### 架构搜索 python3 cnn/train_search.py --batch_size 100 运行结束后会在当前目录下生成./search * /log.txt文件,架构格式如下: genotype = Genotype(normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_5x5', 1), ('dil_conv_5x5', 3), ('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], normal_concat=range(2, 6), reduce=[('skip_connect', 0), ('skip_connect', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2)], reduce_concat=range(2, 6)) ### 格式转化 需使用nasnet/protoc工具将上述得到的genotype转换为protobuf格式,操作如下: cd nasnet/protoc 更改util.py中的main()函数,将架构描述填入LegacyGenotype() ```python def main(): PDARTS = LegacyGenotype(normal=[('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_5x5', 1), ('dil_conv_5x5', 3), ('sep_conv_3x3', 1), ('sep_conv_3x3', 3)], normal_concat=range(2, 6), reduce=[('skip_connect', 0), ('skip_connect', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2), ('max_pool_3x3', 0), ('skip_connect', 2)], reduce_concat=range(3, 6)) new_PDARTS = convert_legacy_format_to_protobuf(PDARTS) save_genotype_to_file('pdarts.txt', new_PDARTS) ``` 执行如下指令,生成pdarts.txt文件 python3 main.py ### 架构评估 运行示例 cd evaluation ./evaluate.sh {node_name} 1 0 /path/to/pdarts.txt /path/to/{save_dir} # 参考 [https://github.com/quark0/darts](https://github.com/quark0/darts)